diff --git a/VERSION b/VERSION index 1fe695856..24ff85581 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.28 +0.0.27 diff --git a/docker/.env.example b/docker/.env.example index 54ca489b2..cafc74af9 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -166,26 +166,25 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 # REDIS_URL=redis://redis:6379/0 # ------------------------------------------------------------------------------ -# Stripe (unified credit wallet, disabled by default) +# Stripe (pay-as-you-go page packs, disabled by default) # ------------------------------------------------------------------------------ -# Set TRUE to allow users to buy credit packs via Stripe Checkout. $1 buys -# 1_000_000 micro-USD of credit; both ETL page processing and premium turns -# debit this balance at the actual per-call provider cost from LiteLLM. -STRIPE_CREDIT_BUYING_ENABLED=FALSE +# Set TRUE to allow users to buy additional page packs via Stripe Checkout +STRIPE_PAGE_BUYING_ENABLED=FALSE # STRIPE_SECRET_KEY=sk_test_... # STRIPE_WEBHOOK_SECRET=whsec_... -# STRIPE_CREDIT_PRICE_ID=price_... -# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +# STRIPE_PRICE_ID=price_... +# STRIPE_PAGES_PER_UNIT=1000 # STRIPE_RECONCILIATION_INTERVAL=10m # STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 # STRIPE_RECONCILIATION_BATCH_SIZE=100 -# Auto-reload: top up via a saved Stripe card when the balance drops below -# the user-chosen threshold. Off by default. -# AUTO_RELOAD_ENABLED=FALSE -# AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000 -# AUTO_RELOAD_COOLDOWN_MINUTES=10 +# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of +# credit; premium turns debit the actual per-call provider cost +# reported by LiteLLM, so cheap and expensive models bill proportionally) +# STRIPE_TOKEN_BUYING_ENABLED=FALSE +# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... +# STRIPE_CREDIT_MICROS_PER_UNIT=1000000 # ------------------------------------------------------------------------------ # TTS & STT (Text-to-Speech / Speech-to-Text) @@ -408,16 +407,13 @@ SURFSENSE_ENABLE_DOOM_LOOP=true # ACCESS_TOKEN_LIFETIME_SECONDS=86400 # REFRESH_TOKEN_LIFETIME_SECONDS=1209600 -# Unified credit wallet starting balance for new users, in micro-USD -# (default: $5). Funds both ETL page processing and premium model calls, -# debited at the actual per-call provider cost reported by LiteLLM. -# DEFAULT_CREDIT_MICROS_BALANCE=5000000 +# Pages limit per user for ETL (default: unlimited) +# PAGES_LIMIT=500 -# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL -# effectively free for self-hosted installs. 1 page == MICROS_PER_PAGE -# micro-USD ($0.001); premium ETL mode is 10x. -# ETL_CREDIT_BILLING_ENABLED=FALSE -# MICROS_PER_PAGE=1000 +# Premium credit quota per registered user, in micro-USD (default: $5). +# Premium turns are debited at the actual per-call provider cost reported +# by LiteLLM. Only applies to models with billing_tier=premium. +# PREMIUM_CREDIT_MICROS_LIMIT=5000000 # Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default). # QUOTA_MAX_RESERVE_MICROS=1000000 diff --git a/docker/scripts/install.ps1 b/docker/scripts/install.ps1 index 23b14c3c4..6e973a520 100644 --- a/docker/scripts/install.ps1 +++ b/docker/scripts/install.ps1 @@ -17,14 +17,10 @@ # into the new PostgreSQL 17 stack. The user runs one command for both. # ============================================================================= -# NOTE: Do not use [ValidateSet()] (or other validation attributes without a -# valid default) on these params. When the script is piped into iex, PowerShell -# applies the attributes to variables in the caller's scope, and an empty -# $Variant fails ValidateSet with a ValidationMetadataException. Validate -# manually below instead. param( [switch]$NoWatchtower, [int]$WatchtowerInterval = 86400, + [ValidateSet("cpu", "cuda", "cuda126")] [string]$Variant, [string]$GpuCount, [switch]$Quiet @@ -44,11 +40,6 @@ $MigrationMode = $false $SetupWatchtower = -not $NoWatchtower $WatchtowerContainer = "watchtower" -if ($Variant -and $Variant -notin @("cpu", "cuda", "cuda126")) { - Write-Host "[SurfSense] ERROR: Invalid -Variant '$Variant'. Use 'cpu', 'cuda', or 'cuda126'." -ForegroundColor Red - exit 1 -} - if ($GpuCount -and $GpuCount -notmatch '^([0-9]+|all)$') { Write-Host "[SurfSense] ERROR: Invalid -GpuCount '$GpuCount'. Use a number or 'all'." -ForegroundColor Red exit 1 diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index b4f67328c..6e49a7132 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -75,16 +75,23 @@ SECRET_KEY=SECRET NEXT_FRONTEND_URL=http://localhost:3000 -# Stripe Checkout for the unified credit wallet. -# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit -# (default 1_000_000 = $1.00). Both ETL page processing and premium model -# turns are billed against this single balance at actual provider cost. +# Stripe Checkout for pay-as-you-go page packs +# Configure STRIPE_PRICE_ID to point at your 1,000-page price in Stripe. +# Pages granted per purchase = quantity * STRIPE_PAGES_PER_UNIT. STRIPE_SECRET_KEY=sk_test_... STRIPE_WEBHOOK_SECRET=whsec_... -STRIPE_CREDIT_PRICE_ID=price_... -STRIPE_CREDIT_MICROS_PER_UNIT=1000000 +STRIPE_PRICE_ID=price_... +STRIPE_PAGES_PER_UNIT=1000 # Set FALSE to disable new checkout session creation temporarily -STRIPE_CREDIT_BUYING_ENABLED=FALSE +STRIPE_PAGE_BUYING_ENABLED=TRUE + +# Premium credit purchases via Stripe (for premium-tier model usage). +# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit +# (default 1_000_000 = $1.00). Premium turns are billed at the actual +# per-call provider cost reported by LiteLLM. +STRIPE_TOKEN_BUYING_ENABLED=FALSE +STRIPE_PREMIUM_TOKEN_PRICE_ID=price_... +STRIPE_CREDIT_MICROS_PER_UNIT=1000000 # Periodic Stripe safety net for purchases left in PENDING (minutes old) STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10 @@ -214,25 +221,15 @@ VIDEO_PRESENTATION_FPS=30 VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300 -# Unified credit wallet starting balance for new users, in micro-USD -# (default: 5,000,000 == $5.00). The same balance funds ETL page processing -# and premium model calls, debited at actual provider cost. -DEFAULT_CREDIT_MICROS_BALANCE=5000000 +# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version) +PAGES_LIMIT=500 -# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL -# effectively free for self-hosted/OSS installs; hosted deployments set TRUE. -# 1 page == MICROS_PER_PAGE micro-USD ($0.001); premium ETL mode is 10x. -ETL_CREDIT_BILLING_ENABLED=FALSE -MICROS_PER_PAGE=1000 - -# Low-balance warning threshold (micro-USD), surfaced to the UI. Default $0.50. -CREDIT_LOW_BALANCE_WARNING_MICROS=500000 - -# Auto-reload: automatically top up via a saved Stripe card when the balance -# drops below the user-chosen threshold. Off by default. -AUTO_RELOAD_ENABLED=FALSE -AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000 -AUTO_RELOAD_COOLDOWN_MINUTES=10 +# Premium credit quota per registered user, in micro-USD +# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the +# actual per-call provider cost reported by LiteLLM, so cheap and expensive +# models bill proportionally. Applies only to models with +# billing_tier=premium in global_llm_config.yaml. +PREMIUM_CREDIT_MICROS_LIMIT=5000000 # Safety ceiling on per-call premium reservation, in micro-USD. # stream_new_chat estimates an upper-bound cost from the model's diff --git a/surfsense_backend/.gitignore b/surfsense_backend/.gitignore index bda5961fe..efc6c90d7 100644 --- a/surfsense_backend/.gitignore +++ b/surfsense_backend/.gitignore @@ -1,12 +1,12 @@ .env .venv venv/ -/data/ +data/ .local_object_store/ __pycache__/ .flashrank_cache surf_new_backend.egg-info/ -/podcasts/ +podcasts/ video_presentation_audio/ sandbox_files/ temp_audio/ diff --git a/surfsense_backend/alembic/versions/156_unify_credits_wallet.py b/surfsense_backend/alembic/versions/156_unify_credits_wallet.py deleted file mode 100644 index 1ecf1a255..000000000 --- a/surfsense_backend/alembic/versions/156_unify_credits_wallet.py +++ /dev/null @@ -1,235 +0,0 @@ -"""unify page limits and premium credits into a single credit_micros_balance wallet - -Collapses the two separate economies (ETL ``pages_limit``/``pages_used`` and -premium ``premium_credit_micros_limit``/``premium_credit_micros_used``) into one -USD-micro wallet column ``user.credit_micros_balance`` that decreases on use and -increases on purchase / grant. ``premium_credit_micros_reserved`` is kept (renamed -to ``credit_micros_reserved``) for in-flight reservation holds. - -Backfill (per existing user row): - - balance = GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used) - + (CASE WHEN pages_limit < 100000000 - THEN GREATEST(0, pages_limit - pages_used) * 1000 - ELSE 0 END) - -The ``pages_limit < 100000000`` guard skips the OSS "unlimited" default -(``PAGES_LIMIT=999999999``) so self-hosters don't get a ~$1M credit grant. -1 page == 1000 micros == $0.001 (matches the prior $1 / 1000 pages price). - -Table / type renames: - - premium_token_purchases -> credit_purchases - premiumtokenpurchasestatus (enum)-> creditpurchasestatus - user_incentive_tasks.pages_awarded -> credit_micros_awarded (backfilled * 1000) - -The "user" table is in zero_publication's column list, so this migration updates -the publication via ``apply_publication`` (canonical reconcile, per migration 155) -BEFORE dropping the old columns it referenced. - -IMPORTANT - before AND after running this migration (same as migration 140): - 1. Stop zero-cache (it holds replication locks that will deadlock DDL) - 2. Run: alembic upgrade head - 3. Delete / reset the zero-cache data volume - 4. Restart zero-cache (it will do a fresh initial sync) - -Revision ID: 156 -Revises: 155 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op -from app.zero_publication import apply_publication - -revision: str = "156" -down_revision: str | None = "155" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def _column_exists(conn, table: str, column: str) -> bool: - return ( - conn.execute( - sa.text( - "SELECT 1 FROM information_schema.columns " - "WHERE table_name = :tbl AND column_name = :col " - "AND table_schema = current_schema()" - ), - {"tbl": table, "col": column}, - ).fetchone() - is not None - ) - - -def _table_exists(conn, table: str) -> bool: - return ( - conn.execute( - sa.text( - "SELECT 1 FROM information_schema.tables " - "WHERE table_name = :tbl AND table_schema = current_schema()" - ), - {"tbl": table}, - ).fetchone() - is not None - ) - - -def _terminate_blocked_pids(conn, table: str) -> None: - """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" - conn.execute( - sa.text( - "SELECT pg_terminate_backend(l.pid) " - "FROM pg_locks l " - "JOIN pg_class c ON c.oid = l.relation " - "WHERE c.relname = :tbl " - " AND l.pid != pg_backend_pid()" - ), - {"tbl": table}, - ) - - -def upgrade() -> None: - conn = op.get_bind() - - # ------------------------------------------------------------------ - # 1. Add credit_micros_balance + backfill from both legacy economies. - # ------------------------------------------------------------------ - if not _column_exists(conn, "user", "credit_micros_balance"): - op.add_column( - "user", - sa.Column( - "credit_micros_balance", - sa.BigInteger(), - nullable=False, - server_default="5000000", - ), - ) - - # Backfill only when ALL legacy source columns are present (fresh DBs - # created from current models won't have them). - if all( - _column_exists(conn, "user", col) - for col in ( - "premium_credit_micros_limit", - "premium_credit_micros_used", - "pages_limit", - "pages_used", - ) - ): - conn.execute( - sa.text( - 'UPDATE "user" SET credit_micros_balance = ' - "GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used) " - "+ (CASE WHEN pages_limit < 100000000 " - " THEN GREATEST(0, pages_limit - pages_used) * 1000 " - " ELSE 0 END)" - ) - ) - - # ------------------------------------------------------------------ - # 2. Rename premium_credit_micros_reserved -> credit_micros_reserved. - # ------------------------------------------------------------------ - if _column_exists( - conn, "user", "premium_credit_micros_reserved" - ) and not _column_exists(conn, "user", "credit_micros_reserved"): - op.alter_column( - "user", - "premium_credit_micros_reserved", - new_column_name="credit_micros_reserved", - ) - - # ------------------------------------------------------------------ - # 3. Reconcile the Zero publication to the new column list - # (id, credit_micros_balance) BEFORE dropping the columns it used - # to reference, otherwise Postgres rejects the column drops with - # "cannot drop column ... referenced by publication". - # ------------------------------------------------------------------ - conn.execute(sa.text("SET lock_timeout = '10s'")) - _terminate_blocked_pids(conn, "user") - apply_publication(conn) - - # ------------------------------------------------------------------ - # 4. Drop the legacy quota columns now that nothing references them. - # ------------------------------------------------------------------ - for col in ( - "premium_credit_micros_limit", - "premium_credit_micros_used", - "pages_limit", - "pages_used", - ): - if _column_exists(conn, "user", col): - op.drop_column("user", col) - - # ------------------------------------------------------------------ - # 5. Rename premium_token_purchases -> credit_purchases and its enum. - # ------------------------------------------------------------------ - op.execute( - """ - DO $$ - BEGIN - IF EXISTS ( - SELECT 1 FROM pg_type t - JOIN pg_namespace n ON n.oid = t.typnamespace - WHERE t.typname = 'premiumtokenpurchasestatus' - AND n.nspname = current_schema() - ) - AND NOT EXISTS ( - SELECT 1 FROM pg_type t - JOIN pg_namespace n ON n.oid = t.typnamespace - WHERE t.typname = 'creditpurchasestatus' - AND n.nspname = current_schema() - ) - THEN - ALTER TYPE premiumtokenpurchasestatus RENAME TO creditpurchasestatus; - END IF; - END - $$; - """ - ) - - if _table_exists(conn, "premium_token_purchases") and not _table_exists( - conn, "credit_purchases" - ): - op.rename_table("premium_token_purchases", "credit_purchases") - - # ``source`` distinguishes user checkout from auto-reload top-ups. - if _table_exists(conn, "credit_purchases") and not _column_exists( - conn, "credit_purchases", "source" - ): - op.add_column( - "credit_purchases", - sa.Column( - "source", - sa.String(length=20), - nullable=False, - server_default="checkout", - ), - ) - - # ------------------------------------------------------------------ - # 6. Rename user_incentive_tasks.pages_awarded -> credit_micros_awarded - # and convert page counts to micros (1 page == 1000 micros). - # ------------------------------------------------------------------ - if _column_exists( - conn, "user_incentive_tasks", "pages_awarded" - ) and not _column_exists(conn, "user_incentive_tasks", "credit_micros_awarded"): - op.alter_column( - "user_incentive_tasks", - "pages_awarded", - new_column_name="credit_micros_awarded", - type_=sa.BigInteger(), - ) - conn.execute( - sa.text( - "UPDATE user_incentive_tasks " - "SET credit_micros_awarded = credit_micros_awarded * 1000" - ) - ) - - -def downgrade() -> None: - """No-op. This is a one-way data-model unification; the legacy split - columns cannot be faithfully reconstructed from a single balance.""" diff --git a/surfsense_backend/alembic/versions/157_add_auto_reload_columns.py b/surfsense_backend/alembic/versions/157_add_auto_reload_columns.py deleted file mode 100644 index ef021b6d2..000000000 --- a/surfsense_backend/alembic/versions/157_add_auto_reload_columns.py +++ /dev/null @@ -1,92 +0,0 @@ -"""add auto-reload (off-session Stripe top-up) columns to user - -Adds the saved-card + threshold plumbing that powers feature-flagged credit -auto-reload (``AUTO_RELOAD_ENABLED``): - - user.stripe_customer_id (text, nullable) - user.auto_reload_enabled (bool, default false) - user.auto_reload_threshold_micros (bigint, nullable) - user.auto_reload_amount_micros (bigint, nullable) - user.auto_reload_payment_method_id (text, nullable) - user.auto_reload_failed_at (timestamptz, nullable) - -None of these columns are part of the Zero publication (``USER_COLS`` is -``["id", "credit_micros_balance"]``), so this migration does NOT touch the -publication and is safe to run without the zero-cache stop/reset dance that -migration 156 required. - -The ``credit_purchases.source`` column (``checkout`` | ``auto_reload``) was -already added in migration 156, so it is not repeated here. - -Revision ID: 157 -Revises: 156 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "157" -down_revision: str | None = "156" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def _column_exists(conn, table: str, column: str) -> bool: - return ( - conn.execute( - sa.text( - "SELECT 1 FROM information_schema.columns " - "WHERE table_name = :tbl AND column_name = :col " - "AND table_schema = current_schema()" - ), - {"tbl": table, "col": column}, - ).fetchone() - is not None - ) - - -_COLUMNS: list[tuple[str, sa.Column]] = [ - ("stripe_customer_id", sa.Column("stripe_customer_id", sa.String(), nullable=True)), - ( - "auto_reload_enabled", - sa.Column( - "auto_reload_enabled", - sa.Boolean(), - nullable=False, - server_default=sa.text("false"), - ), - ), - ( - "auto_reload_threshold_micros", - sa.Column("auto_reload_threshold_micros", sa.BigInteger(), nullable=True), - ), - ( - "auto_reload_amount_micros", - sa.Column("auto_reload_amount_micros", sa.BigInteger(), nullable=True), - ), - ( - "auto_reload_payment_method_id", - sa.Column("auto_reload_payment_method_id", sa.String(), nullable=True), - ), - ( - "auto_reload_failed_at", - sa.Column("auto_reload_failed_at", sa.TIMESTAMP(timezone=True), nullable=True), - ), -] - - -def upgrade() -> None: - conn = op.get_bind() - for name, column in _COLUMNS: - if not _column_exists(conn, "user", name): - op.add_column("user", column) - - -def downgrade() -> None: - conn = op.get_bind() - for name, _ in reversed(_COLUMNS): - if _column_exists(conn, "user", name): - op.drop_column("user", name) diff --git a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py deleted file mode 100644 index f3b194cbd..000000000 --- a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py +++ /dev/null @@ -1,118 +0,0 @@ -"""evolve podcasts: expand status lifecycle and add brief/transcript/storage columns - -Revision ID: 158 -Revises: 157 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "158" -down_revision: str | None = "157" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def _drop_podcasts_from_publication() -> None: - """Detach podcasts from zero_publication so status can be retyped. - - Postgres refuses ``ALTER COLUMN ... TYPE`` on a column a publication - depends on. Some databases reach this migration with podcasts already - published (an interim apply_publication ran during 156); drop it here and - let migration 159 reconcile the publication to the canonical shape. - """ - conn = op.get_bind() - published = conn.execute( - sa.text( - "SELECT 1 FROM pg_publication_tables " - "WHERE pubname = 'zero_publication' " - "AND schemaname = current_schema() AND tablename = 'podcasts'" - ) - ).fetchone() - if published: - op.execute('ALTER PUBLICATION "zero_publication" DROP TABLE "podcasts";') - - -def upgrade() -> None: - _drop_podcasts_from_publication() - - # Retype the status enum by swapping in a fresh type and casting existing - # rows. The legacy transient value 'generating' maps onto 'rendering'. - op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;") - op.execute( - """ - CREATE TYPE podcast_status AS ENUM ( - 'pending', 'awaiting_brief', 'drafting', 'awaiting_review', - 'rendering', 'ready', 'failed', 'cancelled' - ); - """ - ) - op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;") - op.execute( - """ - ALTER TABLE podcasts - ALTER COLUMN status TYPE podcast_status - USING ( - CASE status::text - WHEN 'generating' THEN 'rendering' - ELSE status::text - END - )::podcast_status; - """ - ) - op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'pending';") - op.execute("DROP TYPE podcast_status_old;") - - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;") - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;") - op.execute( - "ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec_version " - "INTEGER NOT NULL DEFAULT 1;" - ) - op.execute( - "ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_backend VARCHAR(32);" - ) - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_key TEXT;") - op.execute( - "ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS duration_seconds INTEGER;" - ) - op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS error TEXT;") - - -def downgrade() -> None: - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_backend;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec_version;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec;") - op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;") - - # Collapse the expanded lifecycle back onto the original four values. - op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_new;") - op.execute( - "CREATE TYPE podcast_status AS ENUM " - "('pending', 'generating', 'ready', 'failed');" - ) - op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;") - op.execute( - """ - ALTER TABLE podcasts - ALTER COLUMN status TYPE podcast_status - USING ( - CASE status::text - WHEN 'awaiting_brief' THEN 'pending' - WHEN 'drafting' THEN 'generating' - WHEN 'awaiting_review' THEN 'generating' - WHEN 'rendering' THEN 'generating' - WHEN 'cancelled' THEN 'failed' - ELSE status::text - END - )::podcast_status; - """ - ) - op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'ready';") - op.execute("DROP TYPE podcast_status_new;") diff --git a/surfsense_backend/alembic/versions/159_publish_podcasts_to_zero.py b/surfsense_backend/alembic/versions/159_publish_podcasts_to_zero.py deleted file mode 100644 index 1667ca96b..000000000 --- a/surfsense_backend/alembic/versions/159_publish_podcasts_to_zero.py +++ /dev/null @@ -1,26 +0,0 @@ -"""publish podcasts to zero_publication - -Reconciles ``zero_publication`` after migration 158 added the lifecycle columns, -so the frontend observes podcast status and the reviewable brief by push. - -Revision ID: 159 -Revises: 158 -""" - -from collections.abc import Sequence - -from alembic import op -from app.zero_publication import apply_publication - -revision: str = "159" -down_revision: str | None = "158" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - apply_publication(op.get_bind()) - - -def downgrade() -> None: - """No-op. Historical publication shapes are immutable.""" diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md index aa6217041..28cf0ac63 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/system_prompt/prompts/routing.md @@ -126,25 +126,23 @@ user: "Create issues in Linear for each of these five bugs: " user: "Make a 30-second podcast of this conversation." -→ Podcast deliverable. The `deliverables` subagent sets the podcast up and - returns **immediately** — generation does not happen during the call. A - live card in the chat takes over from there: the user reviews the brief - (language, voices, length) on the card, and the episode drafts and - renders automatically after they approve. +→ Celery-backed deliverable. The `deliverables` subagent dispatches the + Celery job and then **waits for it to finish** before returning. The + call may take 10-60 seconds (or longer for video presentations) — + that is intentional, not a hang. You always get back one of two + Receipt shapes: task(deliverables, "Generate a podcast titled '' from the - following content. Aim for a 30-second style brief. Return the - podcast id and title.\n\n<source content>") + following content. Use a 30-second style brief. Return the podcast + id and title.\n\n<source content>") Outcomes: - - **`status="success"`**: the podcast is set up. Do NOT describe its - current status or promise it is ready — the card tracks progress - live and will outlive whatever you say. Just point the user at the - card in the chat. + - **`status="success"`**: the audio is saved. Tell the user the + podcast is **ready** and quote the `external_id` / `preview` so + they can find it in the podcast panel. - **`status="failed"`**: surface the Receipt's `error` field verbatim. Do NOT silently re-dispatch — the backend already tried and reported a real error. - Video presentations differ: that Celery-backed call **waits for the - render to finish** before returning (possibly minutes — intentional, - not a hang) and ends with a terminal status. If a + Same two-way pattern applies to video presentations (which take + longer to render, but still return a terminal status). If a `task(deliverables, ...)` invocation itself times out at the subagent layer (separate from the Receipt), that's an operator-side problem with the subagent invoke timeout, not a deliverable failure — pass diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py index d8d28ceb1..bfa3cc100 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/podcast.py @@ -1,10 +1,11 @@ """Factory for a podcast-generation tool. -Creates the podcast and proposes its brief (language, voices, length) inline, -then returns immediately with the row awaiting review. Everything after — -brief approval, drafting, rendering — happens on the live podcast card, so -this tool never blocks on generation and the chat text must not describe a -status that the card will outgrow. +Dispatches the heavy generation to Celery and then polls the podcast row +until it reaches a terminal status (READY/FAILED). The tool always +returns a real terminal ``Receipt`` — never a pending one. The wait is +bounded by the existing per-invocation safety net +(``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` in multi-agent mode, +HTTP / process lifetime in single-agent mode). """ import logging @@ -17,12 +18,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt +from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import ( + wait_for_deliverable, +) from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import ( resolve_root_thread_id, ) -from app.db import PodcastStatus, shielded_async_session -from app.podcasts.generation.brief import propose_brief -from app.podcasts.service import PodcastService +from app.db import Podcast, PodcastStatus, shielded_async_session logger = logging.getLogger(__name__) @@ -43,7 +45,7 @@ def create_generate_podcast_tool( user_prompt: str | None = None, ) -> Command: """ - Prepare a podcast from the provided content for the user to review. + Generate a podcast from the provided content. Use this tool when the user asks to create, generate, or make a podcast. Common triggers include phrases like: @@ -53,59 +55,100 @@ def create_generate_podcast_tool( - "Make a podcast about..." - "Turn this into a podcast" - This sets up the podcast and proposes its brief (language, voices, - length). The user reviews the brief on the live podcast card in the - chat; after approval the episode drafts and renders automatically. - Generation does not start here, and the card tracks all progress — do - not describe the podcast's current status in your reply. - Args: source_content: The text content to convert into a podcast. podcast_title: Title for the podcast (default: "SurfSense Podcast") - user_prompt: Optional steer for what the episode should focus on. + user_prompt: Optional instructions for podcast style, tone, or format. Returns: A dictionary containing: - - status: the podcast lifecycle status (awaiting_brief on success) - - podcast_id: the podcast ID to review in the panel - - title: the podcast title - - message: what the user should do next (or "error" when failed) + - status: PodcastStatus value (pending, generating, or failed) + - podcast_id: The podcast ID for polling (when status is pending or generating) + - title: The podcast title + - message: Status message (or "error" field if status is failed) """ try: # One DB session per tool call so parallel invocations never share an AsyncSession. async with shielded_async_session() as session: - service = PodcastService(session) - podcast = await service.create( + podcast = Podcast( title=podcast_title, + status=PodcastStatus.PENDING, search_space_id=search_space_id, thread_id=resolve_root_thread_id(runtime, thread_id), ) - podcast.source_content = source_content - spec = await propose_brief( - session, - search_space_id=search_space_id, - focus=user_prompt, - ) - await service.attach_brief(podcast, spec) + session.add(podcast) await session.commit() + await session.refresh(podcast) podcast_id = podcast.id - logger.info( - "[generate_podcast] Prepared podcast %s awaiting brief review", - podcast_id, + from app.tasks.celery_tasks.podcast_tasks import ( + generate_content_podcast_task, ) - payload: dict[str, Any] = { - "status": PodcastStatus.AWAITING_BRIEF.value, + task = generate_content_podcast_task.delay( + podcast_id=podcast_id, + source_content=source_content, + search_space_id=search_space_id, + user_prompt=user_prompt, + ) + + logger.info( + "[generate_podcast] Created podcast %s, task: %s", + podcast_id, + task.id, + ) + + # Wait until the Celery worker flips the row to a terminal + # state. The wait is bounded only by the subagent invoke + # timeout (multi-agent) or HTTP lifetime (single-agent) — + # see app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait for details. + terminal_status, columns, elapsed = await wait_for_deliverable( + model=Podcast, + row_id=podcast_id, + columns=[Podcast.status, Podcast.file_location], + terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED}, + ) + + if terminal_status == PodcastStatus.READY: + file_location = columns[1] if columns else None + logger.info( + "[generate_podcast] Podcast %s READY in %.2fs (file=%s)", + podcast_id, + elapsed, + file_location, + ) + payload: dict[str, Any] = { + "status": PodcastStatus.READY.value, + "podcast_id": podcast_id, + "title": podcast_title, + "file_location": file_location, + "message": ("Podcast generated and saved to your podcast panel."), + } + return with_receipt( + payload=payload, + receipt=make_receipt( + route="deliverables", + type="podcast", + operation="generate", + status="success", + external_id=str(podcast_id), + preview=podcast_title, + ), + tool_call_id=runtime.tool_call_id, + ) + + # Only other terminal state is FAILED. + logger.warning( + "[generate_podcast] Podcast %s FAILED in %.2fs", + podcast_id, + elapsed, + ) + err = "Background worker reported FAILED status for this podcast." + payload = { + "status": PodcastStatus.FAILED.value, "podcast_id": podcast_id, "title": podcast_title, - "message": ( - "Podcast set up. The card in the chat handles the rest: " - "the user reviews the brief (language, voices, length) " - "there, and the episode drafts and renders automatically " - "after approval. The card tracks progress live, so do not " - "state the podcast's current status in your reply." - ), + "error": err, } return with_receipt( payload=payload, @@ -113,9 +156,10 @@ def create_generate_podcast_tool( route="deliverables", type="podcast", operation="generate", - status="success", + status="failed", external_id=str(podcast_id), preview=podcast_title, + error=err, ), tool_call_id=runtime.tool_call_id, ) diff --git a/surfsense_backend/app/agents/podcaster/__init__.py b/surfsense_backend/app/agents/podcaster/__init__.py new file mode 100644 index 000000000..8459b2977 --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/__init__.py @@ -0,0 +1,8 @@ +"""New LangGraph Agent. + +This module defines a custom graph. +""" + +from .graph import graph + +__all__ = ["graph"] diff --git a/surfsense_backend/app/agents/podcaster/configuration.py b/surfsense_backend/app/agents/podcaster/configuration.py new file mode 100644 index 000000000..6a903f9df --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/configuration.py @@ -0,0 +1,29 @@ +"""Define the configurable parameters for the agent.""" + +from __future__ import annotations + +from dataclasses import dataclass, fields + +from langchain_core.runnables import RunnableConfig + + +@dataclass(kw_only=True) +class Configuration: + """The configuration for the agent.""" + + # Changeme: Add configurable values here! + # these values can be pre-set when you + # create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/) + # and when you invoke the graph + podcast_title: str + search_space_id: int + user_prompt: str | None = None + + @classmethod + def from_runnable_config( + cls, config: RunnableConfig | None = None + ) -> Configuration: + """Create a Configuration instance from a RunnableConfig object.""" + configurable = (config.get("configurable") or {}) if config else {} + _fields = {f.name for f in fields(cls) if f.init} + return cls(**{k: v for k, v in configurable.items() if k in _fields}) diff --git a/surfsense_backend/app/agents/podcaster/graph.py b/surfsense_backend/app/agents/podcaster/graph.py new file mode 100644 index 000000000..94045566b --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/graph.py @@ -0,0 +1,29 @@ +from langgraph.graph import StateGraph + +from .configuration import Configuration +from .nodes import create_merged_podcast_audio, create_podcast_transcript +from .state import State + + +def build_graph(): + # Define a new graph + workflow = StateGraph(State, config_schema=Configuration) + + # Add the node to the graph + workflow.add_node("create_podcast_transcript", create_podcast_transcript) + workflow.add_node("create_merged_podcast_audio", create_merged_podcast_audio) + + # Set the entrypoint as `call_model` + workflow.add_edge("__start__", "create_podcast_transcript") + workflow.add_edge("create_podcast_transcript", "create_merged_podcast_audio") + workflow.add_edge("create_merged_podcast_audio", "__end__") + + # Compile the workflow into an executable graph + graph = workflow.compile() + graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith + + return graph + + +# Compile the graph once when the module is loaded +graph = build_graph() diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py new file mode 100644 index 000000000..d1f140a44 --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -0,0 +1,195 @@ +import asyncio +import json +import os +import uuid +from pathlib import Path +from typing import Any + +from ffmpeg.asyncio import FFmpeg +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.runnables import RunnableConfig +from litellm import aspeech + +from app.config import config as app_config +from app.services.kokoro_tts_service import get_kokoro_tts_service +from app.services.llm_service import get_agent_llm +from app.utils.content_utils import extract_text_content, strip_markdown_fences + +from .configuration import Configuration +from .prompts import get_podcast_generation_prompt +from .state import PodcastTranscriptEntry, PodcastTranscripts, State +from .utils import get_voice_for_provider + + +async def create_podcast_transcript( + state: State, config: RunnableConfig +) -> dict[str, Any]: + """Generate the podcast transcript from the source content.""" + configuration = Configuration.from_runnable_config(config) + search_space_id = configuration.search_space_id + user_prompt = configuration.user_prompt + + llm = await get_agent_llm(state.db_session, search_space_id) + if not llm: + error_message = f"No agent LLM configured for search space {search_space_id}" + print(error_message) + raise RuntimeError(error_message) + + prompt = get_podcast_generation_prompt(user_prompt) + messages = [ + SystemMessage(content=prompt), + HumanMessage( + content=f"<source_content>{state.source_content}</source_content>" + ), + ] + llm_response = await llm.ainvoke(messages) + + # Reasoning models may return content as blocks; normalise to a string. + content = strip_markdown_fences(extract_text_content(llm_response.content)) + + try: + podcast_transcript = PodcastTranscripts.model_validate(json.loads(content)) + except (json.JSONDecodeError, TypeError, ValueError) as e: + print(f"Direct JSON parsing failed, trying fallback approach: {e!s}") + + try: + json_start = content.find("{") + json_end = content.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: + json_str = content[json_start:json_end] + parsed_data = json.loads(json_str) + podcast_transcript = PodcastTranscripts.model_validate(parsed_data) + print("Successfully parsed podcast transcript using fallback approach") + else: + error_message = f"Could not find valid JSON in LLM response. Raw response: {content}" + print(error_message) + raise ValueError(error_message) + + except (json.JSONDecodeError, TypeError, ValueError) as e2: + error_message = f"Error parsing LLM response (fallback also failed): {e2!s}" + print(f"Error parsing LLM response: {e2!s}") + print(f"Raw response: {content}") + raise + + return {"podcast_transcript": podcast_transcript.podcast_transcripts} + + +async def create_merged_podcast_audio( + state: State, config: RunnableConfig +) -> dict[str, Any]: + """Generate audio for each transcript and merge them into a single podcast file.""" + starting_transcript = PodcastTranscriptEntry( + speaker_id=1, dialog="Welcome to Surfsense Podcast." + ) + + transcript = state.podcast_transcript + + # transcript may be a PodcastTranscripts object or already a list. + if hasattr(transcript, "podcast_transcripts"): + transcript_entries = transcript.podcast_transcripts + else: + transcript_entries = transcript + + merged_transcript = [starting_transcript, *transcript_entries] + + temp_dir = Path("temp_audio") + temp_dir.mkdir(exist_ok=True) + + session_id = str(uuid.uuid4()) + output_path = f"podcasts/{session_id}_podcast.mp3" + os.makedirs("podcasts", exist_ok=True) + + audio_files = [] + + async def generate_speech_for_segment(segment, index): + if hasattr(segment, "speaker_id"): + speaker_id = segment.speaker_id + dialog = segment.dialog + else: + speaker_id = segment.get("speaker_id", 0) + dialog = segment.get("dialog", "") + + voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id) + + if app_config.TTS_SERVICE == "local/kokoro": + filename = f"{temp_dir}/{session_id}_{index}.wav" + else: + filename = f"{temp_dir}/{session_id}_{index}.mp3" + + try: + if app_config.TTS_SERVICE == "local/kokoro": + kokoro_service = await get_kokoro_tts_service( + lang_code="a" + ) # American English + audio_path = await kokoro_service.generate_speech( + text=dialog, voice=voice, speed=1.0, output_path=filename + ) + return audio_path + else: + if app_config.TTS_SERVICE_API_BASE: + response = await aspeech( + model=app_config.TTS_SERVICE, + api_base=app_config.TTS_SERVICE_API_BASE, + api_key=app_config.TTS_SERVICE_API_KEY, + voice=voice, + input=dialog, + max_retries=2, + timeout=600, + ) + else: + response = await aspeech( + model=app_config.TTS_SERVICE, + api_key=app_config.TTS_SERVICE_API_KEY, + voice=voice, + input=dialog, + max_retries=2, + timeout=600, + ) + + with open(filename, "wb") as f: + f.write(response.content) + + return filename + except Exception as e: + print(f"Error generating speech for segment {index}: {e!s}") + raise + + tasks = [ + generate_speech_for_segment(segment, i) + for i, segment in enumerate(merged_transcript) + ] + audio_files = await asyncio.gather(*tasks) + + try: + ffmpeg = FFmpeg().option("y") + for audio_file in audio_files: + ffmpeg = ffmpeg.input(audio_file) + + filter_complex = [] + for i in range(len(audio_files)): + filter_complex.append(f"[{i}:0]") + + filter_complex_str = ( + "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]" + ) + ffmpeg = ffmpeg.option("filter_complex", filter_complex_str) + ffmpeg = ffmpeg.output(output_path, map="[outa]") + await ffmpeg.execute() + + print(f"Successfully created podcast audio: {output_path}") + + except Exception as e: + print(f"Error merging audio files: {e!s}") + raise + finally: + for audio_file in audio_files: + try: + os.remove(audio_file) + except Exception as e: + print(f"Error removing audio file {audio_file}: {e!s}") + pass + + return { + "podcast_transcript": merged_transcript, + "final_podcast_file_path": output_path, + } diff --git a/surfsense_backend/app/agents/podcaster/prompts.py b/surfsense_backend/app/agents/podcaster/prompts.py new file mode 100644 index 000000000..efaa79788 --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/prompts.py @@ -0,0 +1,122 @@ +import datetime + + +def get_podcast_generation_prompt(user_prompt: str | None = None): + return f""" +Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")} +<podcast_generation_system> +You are a master podcast scriptwriter, adept at transforming diverse input content into a lively, engaging, and natural-sounding conversation between two distinct podcast hosts. Your primary objective is to craft authentic, flowing dialogue that captures the spontaneity and chemistry of a real podcast discussion, completely avoiding any hint of robotic scripting or stiff formality. Think dynamic interplay, not just information delivery. + +{ + f''' +You **MUST** strictly adhere to the following user instruction while generating the podcast script: +<user_instruction> +{user_prompt} +</user_instruction> +''' + if user_prompt + else "" + } + +<input> +- '<source_content>': A block of text containing the information to be discussed in the podcast. This could be research findings, an article summary, a detailed outline, user chat history related to the topic, or any other relevant raw information. The content might be unstructured but serves as the factual basis for the podcast dialogue. +</input> + +<output_format> +A JSON object containing the podcast transcript with alternating speakers: +{{ + "podcast_transcripts": [ + {{ + "speaker_id": 0, + "dialog": "Speaker 0 dialog here" + }}, + {{ + "speaker_id": 1, + "dialog": "Speaker 1 dialog here" + }}, + {{ + "speaker_id": 0, + "dialog": "Speaker 0 dialog here" + }}, + {{ + "speaker_id": 1, + "dialog": "Speaker 1 dialog here" + }} + ] +}} +</output_format> + +<guidelines> +1. **Establish Distinct & Consistent Host Personas:** + * **Speaker 0 (Lead Host):** Drives the conversation forward, introduces segments, poses key questions derived from the source content, and often summarizes takeaways. Maintain a guiding, clear, and engaging tone. + * **Speaker 1 (Co-Host/Expert):** Offers deeper insights, provides alternative viewpoints or elaborations on the source content, asks clarifying or challenging questions, and shares relevant anecdotes or examples. Adopt a complementary tone (e.g., analytical, enthusiastic, reflective, slightly skeptical). + * **Consistency is Key:** Ensure each speaker maintains their distinct voice, vocabulary choice, sentence structure, and perspective throughout the entire script. Avoid having them sound interchangeable. Their interaction should feel like a genuine partnership. + +2. **Craft Natural & Dynamic Dialogue:** + * **Emulate Real Conversation:** Use contractions (e.g., "don't", "it's"), interjections ("Oh!", "Wow!", "Hmm"), discourse markers ("you know", "right?", "well"), and occasional natural pauses or filler words. Avoid overly formal language or complex sentence structures typical of written text. + * **Foster Interaction & Chemistry:** Write dialogue where speakers genuinely react *to each other*. They should build on points ("Exactly, and that reminds me..."), ask follow-up questions ("Could you expand on that?"), express agreement/disagreement respectfully ("That's a fair point, but have you considered...?"), and show active listening. + * **Vary Rhythm & Pace:** Mix short, punchy lines with longer, more explanatory ones. Vary sentence beginnings. Use questions to break up exposition. The rhythm should feel spontaneous, not monotonous. + * **Inject Personality & Relatability:** Allow for appropriate humor, moments of surprise or curiosity, brief personal reflections ("I actually experienced something similar..."), or relatable asides that fit the hosts' personas and the topic. Lightly reference past discussions if it enhances context ("Remember last week when we touched on...?"). + +3. **Structure for Flow and Listener Engagement:** + * **Natural Beginning:** Start with dialogue that flows naturally after an introduction (which will be added manually). Avoid redundant greetings or podcast name mentions since these will be added separately. + * **Logical Progression & Signposting:** Guide the listener through the information smoothly. Use clear transitions to link different ideas or segments ("So, now that we've covered X, let's dive into Y...", "That actually brings me to another key finding..."). Ensure topics flow logically from one to the next. + * **Meaningful Conclusion:** Summarize the key takeaways or main points discussed, reinforcing the core message derived from the source content. End with a final thought, a lingering question for the audience, or a brief teaser for what's next, providing a sense of closure. Avoid abrupt endings. + +4. **Integrate Source Content Seamlessly & Accurately:** + * **Translate, Don't Recite:** Rephrase information from the `<source_content>` into conversational language suitable for each host's persona. Avoid directly copying dense sentences or technical jargon without explanation. The goal is discussion, not narration. + * **Explain & Contextualize:** Use analogies, simple examples, storytelling, or have one host ask clarifying questions (acting as a listener surrogate) to break down complex ideas from the source. + * **Weave Information Naturally:** Integrate facts, data, or key points from the source *within* the dialogue, not as standalone, undigested blocks. Attribute information conversationally where appropriate ("The research mentioned...", "Apparently, the key factor is..."). + * **Balance Depth & Accessibility:** Ensure the conversation is informative and factually accurate based on the source content, but prioritize clear communication and engaging delivery over exhaustive technical detail. Make it understandable and interesting for a general audience. + +5. **Length & Pacing:** + * **Six-Minute Duration:** Create a transcript that, when read at a natural speaking pace, would result in approximately 6 minutes of audio. Typically, this means around 1000 words total (based on average speaking rate of 150 words per minute). + * **Concise Speaking Turns:** Keep most speaking turns relatively brief and focused. Aim for a natural back-and-forth rhythm rather than extended monologues. + * **Essential Content Only:** Prioritize the most important information from the source content. Focus on quality over quantity, ensuring every line contributes meaningfully to the topic. +</guidelines> + +<examples> +Input: "Quantum computing uses quantum bits or qubits which can exist in multiple states simultaneously due to superposition." + +Output: +{{ + "podcast_transcripts": [ + {{ + "speaker_id": 0, + "dialog": "Today we're diving into the mind-bending world of quantum computing. You know, this is a topic I've been excited to cover for weeks." + }}, + {{ + "speaker_id": 1, + "dialog": "Same here! And I know our listeners have been asking for it. But I have to admit, the concept of quantum computing makes my head spin a little. Can we start with the basics?" + }}, + {{ + "speaker_id": 0, + "dialog": "Absolutely. So regular computers use bits, right? Little on-off switches that are either 1 or 0. But quantum computers use something called qubits, and this is where it gets fascinating." + }}, + {{ + "speaker_id": 1, + "dialog": "Wait, what makes qubits so special compared to regular bits?" + }}, + {{ + "speaker_id": 0, + "dialog": "The magic is in something called superposition. These qubits can exist in multiple states at the same time, not just 1 or 0." + }}, + {{ + "speaker_id": 1, + "dialog": "That sounds impossible! How would you even picture that?" + }}, + {{ + "speaker_id": 0, + "dialog": "Think of it like a coin spinning in the air. Before it lands, is it heads or tails?" + }}, + {{ + "speaker_id": 1, + "dialog": "Well, it's... neither? Or I guess both, until it lands? Oh, I think I see where you're going with this." + }} + ] +}} +</examples> + +Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration. +</podcast_generation_system> +""" diff --git a/surfsense_backend/app/agents/podcaster/state.py b/surfsense_backend/app/agents/podcaster/state.py new file mode 100644 index 000000000..62eb0537b --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/state.py @@ -0,0 +1,43 @@ +"""Define the state structures for the agent.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + + +class PodcastTranscriptEntry(BaseModel): + """ + Represents a single entry in a podcast transcript. + """ + + speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)") + dialog: str = Field(..., description="The dialog text spoken by the speaker") + + +class PodcastTranscripts(BaseModel): + """ + Represents the full podcast transcript structure. + """ + + podcast_transcripts: list[PodcastTranscriptEntry] = Field( + ..., description="List of transcript entries with alternating speakers" + ) + + +@dataclass +class State: + """Defines the input state for the agent, representing a narrower interface to the outside world. + + This class is used to define the initial state and structure of incoming data. + See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state + for more information. + """ + + # Runtime context + db_session: AsyncSession + source_content: str + podcast_transcript: list[PodcastTranscriptEntry] | None = None + final_podcast_file_path: str | None = None diff --git a/surfsense_backend/app/agents/podcaster/utils.py b/surfsense_backend/app/agents/podcaster/utils.py new file mode 100644 index 000000000..96ea1d51e --- /dev/null +++ b/surfsense_backend/app/agents/podcaster/utils.py @@ -0,0 +1,84 @@ +def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str: + """ + Get the appropriate voice configuration based on the TTS provider and speaker ID. + + Args: + provider: The TTS provider (e.g., "openai/tts-1", "vertex_ai/test") + speaker_id: The ID of the speaker (0-5) + + Returns: + Voice configuration - string for OpenAI, dict for Vertex AI + """ + if provider == "local/kokoro": + # Kokoro voice mapping - https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices + kokoro_voices = { + 0: "am_adam", # Default/intro voice + 1: "af_bella", # First speaker + } + return kokoro_voices.get(speaker_id, "af_heart") + + # Extract provider type from the model string + provider_type = ( + provider.split("/")[0].lower() if "/" in provider else provider.lower() + ) + + if provider_type == "openai": + # OpenAI voice mapping - simple string values + openai_voices = { + 0: "alloy", # Default/intro voice + 1: "echo", # First speaker + 2: "fable", # Second speaker + 3: "onyx", # Third speaker + 4: "nova", # Fourth speaker + 5: "shimmer", # Fifth speaker + } + return openai_voices.get(speaker_id, "alloy") + + elif provider_type == "vertex_ai": + # Vertex AI voice mapping - dict with languageCode and name + vertex_voices = { + 0: { + "languageCode": "en-US", + "name": "en-US-Studio-O", + }, + 1: { + "languageCode": "en-US", + "name": "en-US-Studio-M", + }, + 2: { + "languageCode": "en-UK", + "name": "en-UK-Studio-A", + }, + 3: { + "languageCode": "en-UK", + "name": "en-UK-Studio-B", + }, + 4: { + "languageCode": "en-AU", + "name": "en-AU-Studio-A", + }, + 5: { + "languageCode": "en-AU", + "name": "en-AU-Studio-B", + }, + } + return vertex_voices.get(speaker_id, vertex_voices[0]) + elif provider_type == "azure": + # OpenAI voice mapping - simple string values + azure_voices = { + 0: "alloy", # Default/intro voice + 1: "echo", # First speaker + 2: "fable", # Second speaker + 3: "onyx", # Third speaker + 4: "nova", # Fourth speaker + 5: "shimmer", # Fifth speaker + } + return azure_voices.get(speaker_id, "alloy") + + else: + # Default fallback to OpenAI format for unknown providers + default_voices = { + 0: {}, + 1: {}, + } + return default_voices.get(speaker_id, default_voices[0]) diff --git a/surfsense_backend/app/agents/video_presentation/__init__.py b/surfsense_backend/app/agents/video_presentation/__init__.py index 8a51eb0ef..caf885218 100644 --- a/surfsense_backend/app/agents/video_presentation/__init__.py +++ b/surfsense_backend/app/agents/video_presentation/__init__.py @@ -1,7 +1,8 @@ """Video Presentation LangGraph Agent. -This module defines a graph for generating slide-based video presentations -from source content, with TTS narration per slide. +This module defines a graph for generating video presentations +from source content, similar to the podcaster agent but producing +slide-based video presentations with TTS narration. """ from .graph import graph diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 5eebffd65..0e852b801 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -181,8 +181,7 @@ celery_app = Celery( backend=CELERY_RESULT_BACKEND, include=[ "app.tasks.celery_tasks.document_tasks", - "app.podcasts.tasks.draft", - "app.podcasts.tasks.render", + "app.tasks.celery_tasks.podcast_tasks", "app.tasks.celery_tasks.video_presentation_tasks", "app.tasks.celery_tasks.connector_tasks", "app.tasks.celery_tasks.obsidian_tasks", @@ -190,7 +189,6 @@ celery_app = Celery( "app.tasks.celery_tasks.document_reindex_tasks", "app.tasks.celery_tasks.stale_notification_cleanup_task", "app.tasks.celery_tasks.stripe_reconciliation_task", - "app.tasks.celery_tasks.auto_reload_task", "app.tasks.celery_tasks.gateway_tasks", "app.automations.tasks.execute_run", "app.automations.triggers.builtin.schedule.selector", @@ -283,9 +281,16 @@ celery_app.conf.beat_schedule = { "expires": 60, # Task expires after 60 seconds if not picked up }, }, - # Reconcile Stripe credit purchases that were paid but remained pending - "reconcile-pending-stripe-credit-purchases": { - "task": "reconcile_pending_stripe_credit_purchases", + # Reconcile Stripe purchases that were paid but remained pending + "reconcile-pending-stripe-page-purchases": { + "task": "reconcile_pending_stripe_page_purchases", + "schedule": crontab(**stripe_reconciliation_schedule_params), + "options": { + "expires": 60, + }, + }, + "reconcile-pending-stripe-token-purchases": { + "task": "reconcile_pending_stripe_token_purchases", "schedule": crontab(**stripe_reconciliation_schedule_params), "options": { "expires": 60, diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index bbaf3ac55..75af17d11 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -640,9 +640,14 @@ class Config: ) GATEWAY_DISCORD_REDIRECT_URI = os.getenv("GATEWAY_DISCORD_REDIRECT_URI") - # Stripe checkout (shared secrets for the unified credit wallet) + # Stripe checkout for pay-as-you-go page packs STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY") STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET") + STRIPE_PRICE_ID = os.getenv("STRIPE_PRICE_ID") + STRIPE_PAGES_PER_UNIT = int(os.getenv("STRIPE_PAGES_PER_UNIT", "1000")) + STRIPE_PAGE_BUYING_ENABLED = ( + os.getenv("STRIPE_PAGE_BUYING_ENABLED", "TRUE").upper() == "TRUE" + ) STRIPE_RECONCILIATION_LOOKBACK_MINUTES = int( os.getenv("STRIPE_RECONCILIATION_LOOKBACK_MINUTES", "10") ) @@ -650,56 +655,27 @@ class Config: os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100") ) - # Unified credit wallet (micro-USD) settings. + # Premium credit (micro-USD) quota settings. # - # Storage unit is integer micro-USD (1_000_000 = $1.00). A single - # ``credit_micros_balance`` funds both ETL page processing and premium - # model calls. New users start with ``DEFAULT_CREDIT_MICROS_BALANCE`` - # ($5 by default). - # - # Legacy env names (``PREMIUM_CREDIT_MICROS_LIMIT`` / ``PREMIUM_TOKEN_LIMIT``, - # ``STRIPE_PREMIUM_TOKEN_PRICE_ID``, ``STRIPE_CREDIT_MICROS_PER_UNIT`` / - # ``STRIPE_TOKENS_PER_UNIT``, ``STRIPE_TOKEN_BUYING_ENABLED``) are still - # honoured as fall-backs for one release; deprecation warnings fire below. - DEFAULT_CREDIT_MICROS_BALANCE = int( - os.getenv("DEFAULT_CREDIT_MICROS_BALANCE") - or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") + # Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy + # ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are + # still honoured for one release as fall-back values — the prior + # $1-per-1M-tokens Stripe price means every existing value maps 1:1 + # to micros, so operators upgrading without changing their .env still + # get correct behaviour. A startup deprecation warning fires below if + # they're set. + PREMIUM_CREDIT_MICROS_LIMIT = int( + os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000") ) - STRIPE_CREDIT_PRICE_ID = os.getenv("STRIPE_CREDIT_PRICE_ID") or os.getenv( - "STRIPE_PREMIUM_TOKEN_PRICE_ID" - ) + STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") STRIPE_CREDIT_MICROS_PER_UNIT = int( os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT") or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000") ) - STRIPE_CREDIT_BUYING_ENABLED = ( - os.getenv("STRIPE_CREDIT_BUYING_ENABLED") - or os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE") - ).upper() == "TRUE" - - # ETL page processing debits the credit wallet only when enabled. Defaults - # to FALSE so self-hosted / OSS installs keep effectively-free ETL; hosted - # deployments set this TRUE. 1 page == ``MICROS_PER_PAGE`` micro-USD. - ETL_CREDIT_BILLING_ENABLED = ( - os.getenv("ETL_CREDIT_BILLING_ENABLED", "FALSE").upper() == "TRUE" + STRIPE_TOKEN_BUYING_ENABLED = ( + os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE" ) - MICROS_PER_PAGE = int(os.getenv("MICROS_PER_PAGE", "1000")) - - # Low-balance WARNING threshold (micro-USD). Surfaced by the quota service - # so the UI can nudge the user to top up / enable auto-reload. $0.50. - CREDIT_LOW_BALANCE_WARNING_MICROS = int( - os.getenv("CREDIT_LOW_BALANCE_WARNING_MICROS", "500000") - ) - - # Auto-reload (off-session Stripe top-up) feature flag and guards. - AUTO_RELOAD_ENABLED = os.getenv("AUTO_RELOAD_ENABLED", "FALSE").upper() == "TRUE" - # Minimum configurable reload amount (micro-USD). $1.00 to match pack pricing. - AUTO_RELOAD_MIN_AMOUNT_MICROS = int( - os.getenv("AUTO_RELOAD_MIN_AMOUNT_MICROS", "1000000") - ) - # Cooldown so a burst of debits can't fire multiple charges (minutes). - AUTO_RELOAD_COOLDOWN_MINUTES = int(os.getenv("AUTO_RELOAD_COOLDOWN_MINUTES", "10")) # Safety ceiling on the per-call premium reservation. ``stream_new_chat`` # estimates an upper-bound cost from ``litellm.get_model_info`` x the @@ -709,13 +685,14 @@ class Config: # reserve_tokens ≈ $0.36) with headroom. QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000")) - if ( - os.getenv("PREMIUM_TOKEN_LIMIT") or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT") - ) and not os.getenv("DEFAULT_CREDIT_MICROS_BALANCE"): + if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv( + "PREMIUM_CREDIT_MICROS_LIMIT" + ): print( - "Warning: PREMIUM_TOKEN_LIMIT / PREMIUM_CREDIT_MICROS_LIMIT are " - "deprecated; rename to DEFAULT_CREDIT_MICROS_BALANCE. The old keys " - "will be removed in a future release." + "Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to " + "PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the " + "current Stripe price). The old key will be removed in a " + "future release." ) if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv( "STRIPE_CREDIT_MICROS_PER_UNIT" @@ -725,22 +702,6 @@ class Config: "STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). " "The old key will be removed in a future release." ) - if os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") and not os.getenv( - "STRIPE_CREDIT_PRICE_ID" - ): - print( - "Warning: STRIPE_PREMIUM_TOKEN_PRICE_ID is deprecated; rename to " - "STRIPE_CREDIT_PRICE_ID. The old key will be removed in a future " - "release." - ) - if os.getenv("STRIPE_TOKEN_BUYING_ENABLED") and not os.getenv( - "STRIPE_CREDIT_BUYING_ENABLED" - ): - print( - "Warning: STRIPE_TOKEN_BUYING_ENABLED is deprecated; rename to " - "STRIPE_CREDIT_BUYING_ENABLED. The old key will be removed in a " - "future release." - ) # Anonymous / no-login mode settings NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE" @@ -942,6 +903,9 @@ class Config: # ETL Service ETL_SERVICE = os.getenv("ETL_SERVICE") + # Pages limit for ETL services (default to very high number for OSS unlimited usage) + PAGES_LIMIT = int(os.getenv("PAGES_LIMIT", "999999999")) + if ETL_SERVICE == "UNSTRUCTURED": # Unstructured API Key UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 2d672131b..6117caecb 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -114,6 +114,13 @@ class SearchSourceConnectorType(StrEnum): COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" +class PodcastStatus(StrEnum): + PENDING = "pending" + GENERATING = "generating" + READY = "ready" + FAILED = "failed" + + class VideoPresentationStatus(StrEnum): PENDING = "pending" GENERATING = "generating" @@ -313,7 +320,7 @@ class PagePurchaseStatus(StrEnum): FAILED = "failed" -class CreditPurchaseStatus(StrEnum): +class PremiumTokenPurchaseStatus(StrEnum): PENDING = "pending" COMPLETED = "completed" FAILED = "failed" @@ -325,27 +332,26 @@ INCENTIVE_TASKS_CONFIG = { IncentiveTaskType.GITHUB_STAR: { "title": "Star our GitHub repository", "description": "Show your support by starring SurfSense on GitHub", - # Credit reward in USD micro-units (1_000_000 == $1.00). $0.03. - "credit_micros_reward": 30000, + "pages_reward": 30, "action_url": "https://github.com/MODSetter/SurfSense", }, IncentiveTaskType.REDDIT_FOLLOW: { "title": "Join our Subreddit", "description": "Join the SurfSense community on Reddit", - "credit_micros_reward": 30000, + "pages_reward": 30, "action_url": "https://www.reddit.com/r/SurfSense/", }, IncentiveTaskType.DISCORD_JOIN: { "title": "Join our Discord", "description": "Join the SurfSense community on Discord", - "credit_micros_reward": 40000, + "pages_reward": 40, "action_url": "https://discord.gg/ejRNvftDp9", }, # Future tasks can be configured here: # IncentiveTaskType.GITHUB_ISSUE: { # "title": "Create an issue", # "description": "Help improve SurfSense by reporting bugs or suggesting features", - # "credit_micros_reward": 50000, + # "pages_reward": 50, # "action_url": "https://github.com/MODSetter/SurfSense/issues/new/choose", # }, } @@ -1530,6 +1536,41 @@ class Chunk(BaseModel, TimestampMixin): document = relationship("Document", back_populates="chunks") +class Podcast(BaseModel, TimestampMixin): + """Podcast model for storing generated podcasts.""" + + __tablename__ = "podcasts" + + title = Column(String(500), nullable=False) + podcast_transcript = Column(JSONB, nullable=True) + file_location = Column(Text, nullable=True) + status = Column( + SQLAlchemyEnum( + PodcastStatus, + name="podcast_status", + create_type=False, + values_callable=lambda x: [e.value for e in x], + ), + nullable=False, + default=PodcastStatus.READY, + server_default="ready", + index=True, + ) + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="podcasts") + + thread_id = Column( + Integer, + ForeignKey("new_chat_threads.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + thread = relationship("NewChatThread") + + class VideoPresentation(BaseModel, TimestampMixin): """Video presentation model for storing AI-generated video presentations. @@ -2028,7 +2069,7 @@ class UserIncentiveTask(BaseModel, TimestampMixin): """ Tracks completed incentive tasks for users. Each user can only complete each task type once. - When a task is completed, the user's credit_micros_balance is increased. + When a task is completed, the user's pages_limit is increased. """ __tablename__ = "user_incentive_tasks" @@ -2047,8 +2088,7 @@ class UserIncentiveTask(BaseModel, TimestampMixin): index=True, ) task_type = Column(SQLAlchemyEnum(IncentiveTaskType), nullable=False, index=True) - # Credit reward granted in USD micro-units (1_000_000 == $1.00). - credit_micros_awarded = Column(BigInteger, nullable=False) + pages_awarded = Column(Integer, nullable=False) completed_at = Column( TIMESTAMP(timezone=True), nullable=False, @@ -2091,18 +2131,18 @@ class PagePurchase(Base, TimestampMixin): user = relationship("User", back_populates="page_purchases") -class CreditPurchase(Base, TimestampMixin): - """Tracks Stripe checkout sessions used to grant credit (USD micro-units). +class PremiumTokenPurchase(Base, TimestampMixin): + """Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units). - Renamed from ``premium_token_purchases`` in migration 156 as part of the - unified-credits wallet. ``credit_micros_granted`` stores the USD-micro - amount added to ``user.credit_micros_balance`` on fulfillment. - - ``source`` distinguishes a user-initiated checkout from an automatic - off-session top-up (auto-reload), added in the auto-reload migration. + Note: the table name is preserved (``premium_token_purchases``) for + operational continuity even though the unit is now USD micro-credits + instead of raw tokens. The ``credit_micros_granted`` column replaced + the legacy ``tokens_granted`` in migration 140; the stored values + were not transformed because the prior $1 = 1M tokens Stripe price + makes the unit conversion 1:1 numerically. """ - __tablename__ = "credit_purchases" + __tablename__ = "premium_token_purchases" __allow_unmapped__ = True id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -2120,18 +2160,15 @@ class CreditPurchase(Base, TimestampMixin): credit_micros_granted = Column(BigInteger, nullable=False) amount_total = Column(Integer, nullable=True) currency = Column(String(10), nullable=True) - source = Column( - String(20), nullable=False, default="checkout", server_default="checkout" - ) status = Column( - SQLAlchemyEnum(CreditPurchaseStatus), + SQLAlchemyEnum(PremiumTokenPurchaseStatus), nullable=False, - default=CreditPurchaseStatus.PENDING, + default=PremiumTokenPurchaseStatus.PENDING, index=True, ) completed_at = Column(TIMESTAMP(timezone=True), nullable=True) - user = relationship("User", back_populates="credit_purchases") + user = relationship("User", back_populates="premium_token_purchases") class SearchSpaceRole(BaseModel, TimestampMixin): @@ -2411,40 +2448,33 @@ if config.AUTH_TYPE == "GOOGLE": back_populates="user", cascade="all, delete-orphan", ) - credit_purchases = relationship( - "CreditPurchase", + premium_token_purchases = relationship( + "PremiumTokenPurchase", back_populates="user", cascade="all, delete-orphan", ) - # Unified credit wallet (USD micro-units, 1_000_000 == $1.00). - # Decreases on use (ETL pages + premium model calls), increases on - # purchase / incentive grant / auto-reload. May dip slightly negative - # when an actual cost exceeds its pre-charge estimate; UI clamps at $0. - credit_micros_balance = Column( + # Page usage tracking for ETL services + pages_limit = Column( + Integer, + nullable=False, + default=config.PAGES_LIMIT, + server_default=str(config.PAGES_LIMIT), + ) + pages_used = Column(Integer, nullable=False, default=0, server_default="0") + + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.DEFAULT_CREDIT_MICROS_BALANCE, - server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - # In-flight reservation holds (released/settled at finalize). - credit_micros_reserved = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - - # Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED. - # ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the - # saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at`` - # is set (and ``auto_reload_enabled`` flipped off) when an off-session - # charge is declined so the UI can prompt the user to fix their card. - stripe_customer_id = Column(String, nullable=True) - auto_reload_enabled = Column( - Boolean, nullable=False, default=False, server_default="false" + premium_credit_micros_reserved = Column( + BigInteger, nullable=False, default=0, server_default="0" ) - auto_reload_threshold_micros = Column(BigInteger, nullable=True) - auto_reload_amount_micros = Column(BigInteger, nullable=True) - auto_reload_payment_method_id = Column(String, nullable=True) - auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True) # User profile from OAuth display_name = Column(String, nullable=True) @@ -2557,40 +2587,33 @@ else: back_populates="user", cascade="all, delete-orphan", ) - credit_purchases = relationship( - "CreditPurchase", + premium_token_purchases = relationship( + "PremiumTokenPurchase", back_populates="user", cascade="all, delete-orphan", ) - # Unified credit wallet (USD micro-units, 1_000_000 == $1.00). - # Decreases on use (ETL pages + premium model calls), increases on - # purchase / incentive grant / auto-reload. May dip slightly negative - # when an actual cost exceeds its pre-charge estimate; UI clamps at $0. - credit_micros_balance = Column( + # Page usage tracking for ETL services + pages_limit = Column( + Integer, + nullable=False, + default=config.PAGES_LIMIT, + server_default=str(config.PAGES_LIMIT), + ) + pages_used = Column(Integer, nullable=False, default=0, server_default="0") + + premium_credit_micros_limit = Column( BigInteger, nullable=False, - default=config.DEFAULT_CREDIT_MICROS_BALANCE, - server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE), + default=config.PREMIUM_CREDIT_MICROS_LIMIT, + server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT), ) - # In-flight reservation holds (released/settled at finalize). - credit_micros_reserved = Column( + premium_credit_micros_used = Column( BigInteger, nullable=False, default=0, server_default="0" ) - - # Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED. - # ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the - # saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at`` - # is set (and ``auto_reload_enabled`` flipped off) when an off-session - # charge is declined so the UI can prompt the user to fix their card. - stripe_customer_id = Column(String, nullable=True) - auto_reload_enabled = Column( - Boolean, nullable=False, default=False, server_default="false" + premium_credit_micros_reserved = Column( + BigInteger, nullable=False, default=0, server_default="0" ) - auto_reload_threshold_micros = Column(BigInteger, nullable=True) - auto_reload_amount_micros = Column(BigInteger, nullable=True) - auto_reload_payment_method_id = Column(String, nullable=True) - auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True) # User profile (can be set manually for non-OAuth users) display_name = Column(String, nullable=True) @@ -2866,10 +2889,6 @@ from app.automations.persistence import ( # noqa: E402, F401 ) from app.file_storage.persistence import DocumentFile # noqa: E402, F401 from app.notifications.persistence import Notification # noqa: E402, F401 -from app.podcasts.persistence import ( # noqa: E402, F401 - Podcast, - PodcastStatus, -) engine = create_async_engine( DATABASE_URL, diff --git a/surfsense_backend/app/notifications/api/api.py b/surfsense_backend/app/notifications/api/api.py index 9a136ca7b..ddca09c66 100644 --- a/surfsense_backend/app/notifications/api/api.py +++ b/surfsense_backend/app/notifications/api/api.py @@ -275,7 +275,7 @@ async def list_notifications( query = query.where(unread_filter) count_query = count_query.where(unread_filter) elif filter == "errors": - error_filter = (Notification.type == "insufficient_credits") | ( + error_filter = (Notification.type == "page_limit_exceeded") | ( Notification.notification_metadata["status"].astext == "failed" ) query = query.where(error_filter) diff --git a/surfsense_backend/app/notifications/constants.py b/surfsense_backend/app/notifications/constants.py index 6fc13e3c7..e8bd8391d 100644 --- a/surfsense_backend/app/notifications/constants.py +++ b/surfsense_backend/app/notifications/constants.py @@ -12,7 +12,6 @@ CATEGORY_TYPES: dict[str, tuple[str, ...]] = { "connector_indexing", "connector_deletion", "document_processing", - "insufficient_credits", - "auto_reload_failed", + "page_limit_exceeded", ), } diff --git a/surfsense_backend/app/notifications/service/facade.py b/surfsense_backend/app/notifications/service/facade.py index 9f4ad50d0..63154301c 100644 --- a/surfsense_backend/app/notifications/service/facade.py +++ b/surfsense_backend/app/notifications/service/facade.py @@ -10,12 +10,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.notifications.persistence import Notification from app.notifications.service.handlers import ( - AutoReloadFailedNotificationHandler, CommentReplyNotificationHandler, ConnectorIndexingNotificationHandler, DocumentProcessingNotificationHandler, - InsufficientCreditsNotificationHandler, MentionNotificationHandler, + PageLimitNotificationHandler, ) logger = logging.getLogger(__name__) @@ -28,8 +27,7 @@ class NotificationService: document_processing = DocumentProcessingNotificationHandler() mention = MentionNotificationHandler() comment_reply = CommentReplyNotificationHandler() - insufficient_credits = InsufficientCreditsNotificationHandler() - auto_reload_failed = AutoReloadFailedNotificationHandler() + page_limit = PageLimitNotificationHandler() @staticmethod async def create_notification( diff --git a/surfsense_backend/app/notifications/service/handlers/__init__.py b/surfsense_backend/app/notifications/service/handlers/__init__.py index 1a6168e37..8c32dea3b 100644 --- a/surfsense_backend/app/notifications/service/handlers/__init__.py +++ b/surfsense_backend/app/notifications/service/handlers/__init__.py @@ -2,18 +2,16 @@ from __future__ import annotations -from .auto_reload_failed import AutoReloadFailedNotificationHandler from .comment_reply import CommentReplyNotificationHandler from .connector_indexing import ConnectorIndexingNotificationHandler from .document_processing import DocumentProcessingNotificationHandler -from .insufficient_credits import InsufficientCreditsNotificationHandler from .mention import MentionNotificationHandler +from .page_limit import PageLimitNotificationHandler __all__ = [ - "AutoReloadFailedNotificationHandler", "CommentReplyNotificationHandler", "ConnectorIndexingNotificationHandler", "DocumentProcessingNotificationHandler", - "InsufficientCreditsNotificationHandler", "MentionNotificationHandler", + "PageLimitNotificationHandler", ] diff --git a/surfsense_backend/app/notifications/service/handlers/auto_reload_failed.py b/surfsense_backend/app/notifications/service/handlers/auto_reload_failed.py deleted file mode 100644 index 0234a436d..000000000 --- a/surfsense_backend/app/notifications/service/handlers/auto_reload_failed.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Notifications for failed off-session credit auto-reload charges.""" - -from __future__ import annotations - -import logging -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.notifications.persistence import Notification -from app.notifications.service.base import BaseNotificationHandler -from app.notifications.service.messages import auto_reload_failed as msg - -logger = logging.getLogger(__name__) - - -class AutoReloadFailedNotificationHandler(BaseNotificationHandler): - """Notifications for declined auto-reload top-ups.""" - - def __init__(self): - super().__init__("auto_reload_failed") - - async def notify_auto_reload_failed( - self, - session: AsyncSession, - user_id: UUID, - amount_micros: int, - payment_intent_id: str | None = None, - reason: str | None = None, - ) -> Notification: - """Notify that an off-session auto-reload charge was declined. - - Not tied to a search space (``search_space_id`` is None); the action - links to the billing settings so the user can fix their card. - """ - op_id = msg.operation_id(payment_intent_id or "") - title, message = msg.summary(amount_micros, reason) - - return await self.find_or_create_notification( - session=session, - user_id=user_id, - operation_id=op_id, - title=title, - message=message, - search_space_id=None, - initial_metadata={ - "amount_micros": amount_micros, - "payment_intent_id": payment_intent_id, - "status": "failed", - "error_type": "auto_reload_failed", - "action_url": "/dashboard", - "action_label": "Update card", - }, - ) diff --git a/surfsense_backend/app/notifications/service/handlers/insufficient_credits.py b/surfsense_backend/app/notifications/service/handlers/page_limit.py similarity index 55% rename from surfsense_backend/app/notifications/service/handlers/insufficient_credits.py rename to surfsense_backend/app/notifications/service/handlers/page_limit.py index 46124f222..90722dc62 100644 --- a/surfsense_backend/app/notifications/service/handlers/insufficient_credits.py +++ b/surfsense_backend/app/notifications/service/handlers/page_limit.py @@ -1,4 +1,4 @@ -"""Notifications for running out of credit during document processing.""" +"""Notifications for exceeding the page limit.""" from __future__ import annotations @@ -9,42 +9,46 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.notifications.persistence import Notification from app.notifications.service.base import BaseNotificationHandler -from app.notifications.service.messages import insufficient_credits as msg +from app.notifications.service.messages import page_limit as msg logger = logging.getLogger(__name__) -class InsufficientCreditsNotificationHandler(BaseNotificationHandler): - """Notifications for running out of credit during document processing.""" +class PageLimitNotificationHandler(BaseNotificationHandler): + """Notifications for exceeding the page limit.""" def __init__(self): - super().__init__("insufficient_credits") + super().__init__("page_limit_exceeded") - async def notify_insufficient_credits( + async def notify_page_limit_exceeded( self, session: AsyncSession, user_id: UUID, document_name: str, document_type: str, search_space_id: int, - balance_micros: int, - required_micros: int, + pages_used: int, + pages_limit: int, + pages_to_add: int, ) -> Notification: - """Notify that a document was blocked by insufficient credit.""" + """Notify that a document was blocked by the page limit.""" operation_id = msg.operation_id(document_name, search_space_id) - title, message = msg.summary(document_name, balance_micros, required_micros) + title, message = msg.summary( + document_name, pages_used, pages_limit, pages_to_add + ) metadata = { "operation_id": operation_id, "document_name": document_name, "document_type": document_type, - "balance_micros": balance_micros, - "required_micros": required_micros, + "pages_used": pages_used, + "pages_limit": pages_limit, + "pages_to_add": pages_to_add, "status": "failed", - "error_type": "insufficient_credits", + "error_type": "page_limit_exceeded", # Where the inbox item links to. - "action_url": f"/dashboard/{search_space_id}/buy-more", - "action_label": "Buy credits", + "action_url": f"/dashboard/{search_space_id}/more-pages", + "action_label": "Upgrade Plan", } notification = Notification( @@ -59,7 +63,6 @@ class InsufficientCreditsNotificationHandler(BaseNotificationHandler): await session.commit() await session.refresh(notification) logger.info( - f"Created insufficient_credits notification {notification.id} " - f"for user {user_id}" + f"Created page_limit_exceeded notification {notification.id} for user {user_id}" ) return notification diff --git a/surfsense_backend/app/notifications/service/messages/auto_reload_failed.py b/surfsense_backend/app/notifications/service/messages/auto_reload_failed.py deleted file mode 100644 index 5af19623c..000000000 --- a/surfsense_backend/app/notifications/service/messages/auto_reload_failed.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Pure presentation logic for auto-reload-failure notifications.""" - -from __future__ import annotations - -from datetime import UTC, datetime - - -def operation_id(payment_intent_id: str) -> str: - """Build a unique id for an auto-reload-failure notification. - - Keyed on the failed PaymentIntent so retries of the same charge collapse - into a single inbox item rather than spamming the user. - """ - if payment_intent_id: - return f"auto_reload_failed_{payment_intent_id}" - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") - return f"auto_reload_failed_{timestamp}" - - -def summary(amount_micros: int, reason: str | None) -> tuple[str, str]: - """Compute the title and message for a failed off-session auto-reload charge.""" - amount_usd = max(0, amount_micros) / 1_000_000 - title = "Auto-reload failed" - base = ( - f"We couldn't automatically add ${amount_usd:.2f} of credit because your " - "saved card was declined. Auto-reload has been turned off — update your " - "card and re-enable it to keep topping up automatically." - ) - if reason: - base = f"{base} (Reason: {reason}.)" - return title, base diff --git a/surfsense_backend/app/notifications/service/messages/insufficient_credits.py b/surfsense_backend/app/notifications/service/messages/insufficient_credits.py deleted file mode 100644 index fad26ad91..000000000 --- a/surfsense_backend/app/notifications/service/messages/insufficient_credits.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Pure presentation logic for insufficient-credit notifications.""" - -from __future__ import annotations - -import hashlib -from datetime import UTC, datetime - -from app.notifications.service.messages.text import truncate - - -def operation_id(document_name: str, search_space_id: int) -> str: - """Build a unique id for an insufficient-credits notification.""" - timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") - doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8] - return f"insufficient_credits_{search_space_id}_{timestamp}_{doc_hash}" - - -def summary( - document_name: str, balance_micros: int, required_micros: int -) -> tuple[str, str]: - """Compute the title and message for a blocked-by-insufficient-credits document.""" - display_name = truncate(document_name, 40) - title = f"Insufficient credits: {display_name}" - balance_usd = max(0, balance_micros) / 1_000_000 - required_usd = max(0, required_micros) / 1_000_000 - message = ( - f"This document costs about ${required_usd:.2f} to process but you have " - f"${balance_usd:.2f} of credit left. Add more credits to continue." - ) - return title, message diff --git a/surfsense_backend/app/notifications/service/messages/page_limit.py b/surfsense_backend/app/notifications/service/messages/page_limit.py new file mode 100644 index 000000000..54e5cbdec --- /dev/null +++ b/surfsense_backend/app/notifications/service/messages/page_limit.py @@ -0,0 +1,25 @@ +"""Pure presentation logic for page-limit notifications.""" + +from __future__ import annotations + +import hashlib +from datetime import UTC, datetime + +from app.notifications.service.messages.text import truncate + + +def operation_id(document_name: str, search_space_id: int) -> str: + """Build a unique id for a page-limit notification.""" + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") + doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8] + return f"page_limit_{search_space_id}_{timestamp}_{doc_hash}" + + +def summary( + document_name: str, pages_used: int, pages_limit: int, pages_to_add: int +) -> tuple[str, str]: + """Compute the title and message for a blocked-by-page-limit document.""" + display_name = truncate(document_name, 40) + title = f"Page limit exceeded: {display_name}" + message = f"This document has ~{pages_to_add} page(s) but you've used {pages_used}/{pages_limit} pages. Upgrade to process more documents." + return title, message diff --git a/surfsense_backend/app/notifications/types.py b/surfsense_backend/app/notifications/types.py index f2974e584..bb8bcfab1 100644 --- a/surfsense_backend/app/notifications/types.py +++ b/surfsense_backend/app/notifications/types.py @@ -10,8 +10,7 @@ NotificationType = Literal[ "document_processing", "new_mention", "comment_reply", - "insufficient_credits", - "auto_reload_failed", + "page_limit_exceeded", ] NotificationCategory = Literal["comments", "status"] diff --git a/surfsense_backend/app/podcasts/__init__.py b/surfsense_backend/app/podcasts/__init__.py deleted file mode 100644 index 6a152af22..000000000 --- a/surfsense_backend/app/podcasts/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Podcast feature: brief resolution, transcript drafting, and audio rendering. - -Owns the ``podcasts`` table model, which :mod:`app.db` re-exports so existing -``from app.db import Podcast`` imports keep resolving. -""" - -from __future__ import annotations - -__all__: list[str] = [] diff --git a/surfsense_backend/app/podcasts/api/__init__.py b/surfsense_backend/app/podcasts/api/__init__.py deleted file mode 100644 index 4b5b12971..000000000 --- a/surfsense_backend/app/podcasts/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""HTTP API for the podcast lifecycle.""" - -from __future__ import annotations - -from .routes import router - -__all__ = ["router"] diff --git a/surfsense_backend/app/podcasts/api/routes.py b/surfsense_backend/app/podcasts/api/routes.py deleted file mode 100644 index 80e5e1c64..000000000 --- a/surfsense_backend/app/podcasts/api/routes.py +++ /dev/null @@ -1,337 +0,0 @@ -"""HTTP surface for the podcast lifecycle. - -Status is observed by the frontend through Zero, so these routes are about -actions (create, edit/approve the brief, regenerate, cancel) and audio delivery. -Each mutating route performs the guarded transition via the service, commits, -then enqueues the matching Celery task; lifecycle errors map to 409/422. -""" - -from __future__ import annotations - -import os -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from pathlib import Path - -from fastapi import APIRouter, Depends, HTTPException, Response -from fastapi.responses import StreamingResponse -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config import config as app_config -from app.db import ( - Permission, - SearchSpace, - SearchSpaceMembership, - User, - get_async_session, -) -from app.podcasts.generation.brief import propose_brief -from app.podcasts.persistence import Podcast, PodcastRepository -from app.podcasts.service import ( - InvalidTransitionError, - PodcastService, - PreconditionFailedError, - SpecConflictError, -) -from app.podcasts.storage import open_audio_stream, purge_audio -from app.podcasts.tasks import draft_transcript_task -from app.podcasts.tts import get_text_to_speech -from app.podcasts.voices import ( - get_voice_catalog, - provider_from_service, - render_voice_preview, -) -from app.users import current_active_user -from app.utils.rbac import check_permission - -from .schemas import ( - CreatePodcastRequest, - PodcastDetail, - PodcastSummary, - UpdateSpecRequest, - VoiceOption, -) - -router = APIRouter() - - -@router.get("/podcasts", response_model=list[PodcastSummary]) -async def list_podcasts( - search_space_id: int | None = None, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - if skip < 0 or limit < 1: - raise HTTPException(status_code=400, detail="Invalid pagination parameters") - - if search_space_id is not None: - await _require(session, user, search_space_id, Permission.PODCASTS_READ) - query = ( - select(Podcast) - .where(Podcast.search_space_id == search_space_id) - .order_by(Podcast.created_at.desc()) - .offset(skip) - .limit(limit) - ) - else: - query = ( - select(Podcast) - .join(SearchSpace) - .join(SearchSpaceMembership) - .where(SearchSpaceMembership.user_id == user.id) - .order_by(Podcast.created_at.desc()) - .offset(skip) - .limit(limit) - ) - result = await session.execute(query) - return list(result.scalars().all()) - - -@router.get("/podcasts/voices", response_model=list[VoiceOption]) -async def list_voices(language: str | None = None): - """Voices the active TTS provider offers, optionally filtered by language.""" - if not app_config.TTS_SERVICE: - raise HTTPException(status_code=503, detail="No TTS provider configured") - - provider = provider_from_service(app_config.TTS_SERVICE) - catalog = get_voice_catalog() - voices = ( - catalog.for_language(provider, language) - if language - else catalog.for_provider(provider) - ) - return [ - VoiceOption( - voice_id=v.voice_id, - display_name=v.display_name, - language=v.language, - gender=v.gender.value, - ) - for v in voices - ] - - -@router.get("/podcasts/voices/{voice_id}/preview") -async def preview_voice( - voice_id: str, - user: User = Depends(current_active_user), -): - """A short audio sample of a voice, so users pick by sound.""" - if not app_config.TTS_SERVICE: - raise HTTPException(status_code=503, detail="No TTS provider configured") - - provider = provider_from_service(app_config.TTS_SERVICE) - try: - voice = get_voice_catalog().get(voice_id) - except KeyError: - raise HTTPException(status_code=404, detail="Unknown voice") from None - if voice.provider is not provider: - raise HTTPException( - status_code=404, detail="Voice not offered by the active TTS provider" - ) - - data, content_type = await render_voice_preview(voice, get_text_to_speech()) - return Response(content=data, media_type=content_type) - - -@router.post("/podcasts", response_model=PodcastDetail, status_code=201) -async def create_podcast( - body: CreatePodcastRequest, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE) - - service = PodcastService(session) - podcast = await service.create( - title=body.title, - search_space_id=body.search_space_id, - thread_id=body.thread_id, - ) - podcast.source_content = body.source_content - - spec = await propose_brief( - session, - search_space_id=body.search_space_id, - speaker_count=body.speaker_count, - min_minutes=body.min_minutes, - max_minutes=body.max_minutes, - focus=body.focus, - ) - await service.attach_brief(podcast, spec) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.get("/podcasts/{podcast_id}", response_model=PodcastDetail) -async def get_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) - return PodcastDetail.of(podcast) - - -@router.patch("/podcasts/{podcast_id}/spec", response_model=PodcastDetail) -async def update_spec( - podcast_id: int, - body: UpdateSpecRequest, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).update_spec( - podcast, body.spec, body.expected_version - ) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.post("/podcasts/{podcast_id}/brief/approve", response_model=PodcastDetail) -async def approve_brief( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Approve the brief and start drafting the transcript.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).begin_drafting(podcast) - await session.commit() - draft_transcript_task.delay(podcast.id, podcast.search_space_id) - return PodcastDetail.of(podcast) - - -@router.post( - "/podcasts/{podcast_id}/transcript/regenerate", response_model=PodcastDetail -) -async def regenerate_transcript( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Reopen the brief gate for a fresh take; drafting waits for re-approval.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).regenerate(podcast) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.post("/podcasts/{podcast_id}/regenerate/revert", response_model=PodcastDetail) -async def revert_regeneration( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Back out of a regeneration and return to the finished episode.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).revert_regeneration(podcast) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.post("/podcasts/{podcast_id}/cancel", response_model=PodcastDetail) -async def cancel_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) - async with _lifecycle_errors(): - await PodcastService(session).cancel(podcast) - await session.commit() - return PodcastDetail.of(podcast) - - -@router.delete("/podcasts/{podcast_id}", response_model=dict) -async def delete_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE) - await purge_audio(podcast) - await session.delete(podcast) - await session.commit() - return {"message": "Podcast deleted successfully"} - - -@router.get("/podcasts/{podcast_id}/stream") -async def stream_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) - - if podcast.storage_key: - return StreamingResponse( - open_audio_stream(podcast), - media_type="audio/mpeg", - headers={"Accept-Ranges": "bytes"}, - ) - - # Back-compat: rows rendered before the storage migration kept a local path. - if podcast.file_location and os.path.isfile(podcast.file_location): - path = podcast.file_location - - def iterfile(): - with open(path, mode="rb") as handle: - yield from handle - - return StreamingResponse( - iterfile(), - media_type="audio/mpeg", - headers={ - "Accept-Ranges": "bytes", - "Content-Disposition": f"inline; filename={Path(path).name}", - }, - ) - - raise HTTPException(status_code=404, detail="Podcast audio not found") - - -async def _require( - session: AsyncSession, - user: User, - search_space_id: int, - permission: Permission, -) -> None: - await check_permission( - session, - user, - search_space_id, - permission.value, - "You don't have permission for podcasts in this search space", - ) - - -async def _load( - session: AsyncSession, - user: User, - podcast_id: int, - permission: Permission, -) -> Podcast: - podcast = await PodcastRepository(session).get(podcast_id) - if podcast is None: - raise HTTPException(status_code=404, detail="Podcast not found") - await _require(session, user, podcast.search_space_id, permission) - return podcast - - -@asynccontextmanager -async def _lifecycle_errors() -> AsyncIterator[None]: - """Map service lifecycle errors onto HTTP responses.""" - try: - yield - except (SpecConflictError, InvalidTransitionError) as exc: - raise HTTPException(status_code=409, detail=str(exc)) from exc - except PreconditionFailedError as exc: - raise HTTPException(status_code=422, detail=str(exc)) from exc diff --git a/surfsense_backend/app/podcasts/api/schemas.py b/surfsense_backend/app/podcasts/api/schemas.py deleted file mode 100644 index 7f1f8cc7c..000000000 --- a/surfsense_backend/app/podcasts/api/schemas.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Request and response shapes for the podcast API. - -Read models surface the lifecycle state the frontend can't derive from Zero (the -deserialized brief and transcript); the action requests carry just what each -guarded transition needs. -""" - -from __future__ import annotations - -from datetime import datetime - -from pydantic import BaseModel, ConfigDict, Field - -from app.podcasts.persistence import Podcast, PodcastStatus -from app.podcasts.schemas import PodcastSpec, Transcript -from app.podcasts.service import has_stored_episode, read_spec, read_transcript - -# Defaults applied when a create request omits brief sizing; the brief gate lets -# the user adjust before any cost is incurred. -DEFAULT_SPEAKER_COUNT = 2 -DEFAULT_MIN_MINUTES = 10 -DEFAULT_MAX_MINUTES = 20 - - -class CreatePodcastRequest(BaseModel): - """Create a podcast and kick off brief proposal.""" - - title: str = Field(..., min_length=1, max_length=500) - search_space_id: int - source_content: str = Field(..., min_length=1) - thread_id: int | None = None - speaker_count: int = Field(default=DEFAULT_SPEAKER_COUNT, ge=1, le=6) - min_minutes: int = Field(default=DEFAULT_MIN_MINUTES, ge=1) - max_minutes: int = Field(default=DEFAULT_MAX_MINUTES, ge=1) - focus: str | None = Field(default=None, max_length=2000) - - -class UpdateSpecRequest(BaseModel): - """Replace the brief at the gate, guarded by the expected version.""" - - spec: PodcastSpec - expected_version: int = Field(..., ge=1) - - -class VoiceOption(BaseModel): - """One selectable voice surfaced to the brief editor.""" - - voice_id: str - display_name: str - language: str - gender: str - - -class PodcastSummary(BaseModel): - """Lightweight list item.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - title: str - status: PodcastStatus - created_at: datetime - search_space_id: int - - -class PodcastDetail(BaseModel): - """Full podcast state for the detail view and action responses.""" - - id: int - title: str - status: PodcastStatus - spec_version: int - spec: PodcastSpec | None - transcript: Transcript | None - has_audio: bool - duration_seconds: int | None - error: str | None - created_at: datetime - search_space_id: int - thread_id: int | None - - @classmethod - def of(cls, podcast: Podcast) -> PodcastDetail: - return cls( - id=podcast.id, - title=podcast.title, - status=PodcastStatus(podcast.status), - spec_version=podcast.spec_version, - spec=read_spec(podcast), - transcript=read_transcript(podcast), - has_audio=has_stored_episode(podcast), - duration_seconds=podcast.duration_seconds, - error=podcast.error, - created_at=podcast.created_at, - search_space_id=podcast.search_space_id, - thread_id=podcast.thread_id, - ) diff --git a/surfsense_backend/app/podcasts/generation/__init__.py b/surfsense_backend/app/podcasts/generation/__init__.py deleted file mode 100644 index a30b8f9af..000000000 --- a/surfsense_backend/app/podcasts/generation/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Generation: the controlled graphs that produce a brief and a transcript. - -``brief`` proposes a reviewable spec from deterministic defaults; ``transcript`` -is the LLM-driven step, drafting long-form dialogue outline-first. -""" - -from __future__ import annotations - -from .brief import BriefConfig, BriefState, build_brief_graph -from .transcript import TranscriptConfig, TranscriptState, build_transcript_graph - -__all__ = [ - "BriefConfig", - "BriefState", - "TranscriptConfig", - "TranscriptState", - "build_brief_graph", - "build_transcript_graph", -] diff --git a/surfsense_backend/app/podcasts/generation/brief/__init__.py b/surfsense_backend/app/podcasts/generation/brief/__init__.py deleted file mode 100644 index 5083c4708..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Brief planning: propose a reviewable spec from last-used preferences.""" - -from __future__ import annotations - -from .config import BriefConfig -from .graph import build_brief_graph -from .propose import propose_brief -from .state import BriefState - -__all__ = ["BriefConfig", "BriefState", "build_brief_graph", "propose_brief"] diff --git a/surfsense_backend/app/podcasts/generation/brief/config.py b/surfsense_backend/app/podcasts/generation/brief/config.py deleted file mode 100644 index 4f92585ae..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/config.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Configurable inputs for the brief-planning graph.""" - -from __future__ import annotations - -from dataclasses import dataclass, field, fields - -from langchain_core.runnables import RunnableConfig - -# Sensible defaults for a fresh brief; the user adjusts the range at the gate. -DEFAULT_SPEAKER_COUNT = 2 -DEFAULT_MIN_MINUTES = 10 -DEFAULT_MAX_MINUTES = 20 - - -@dataclass(kw_only=True) -class BriefConfig: - """Signals used to propose a brief; everything here is non-LLM context.""" - - speaker_count: int = DEFAULT_SPEAKER_COUNT - min_minutes: int = DEFAULT_MIN_MINUTES - max_minutes: int = DEFAULT_MAX_MINUTES - focus: str | None = None - last_used_language: str | None = None - last_used_voices: list[str] = field(default_factory=list) - - @classmethod - def from_runnable_config(cls, config: RunnableConfig | None = None) -> BriefConfig: - configurable = (config.get("configurable") or {}) if config else {} - names = {f.name for f in fields(cls) if f.init} - return cls(**{k: v for k, v in configurable.items() if k in names}) diff --git a/surfsense_backend/app/podcasts/generation/brief/graph.py b/surfsense_backend/app/podcasts/generation/brief/graph.py deleted file mode 100644 index a643bdbb4..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/graph.py +++ /dev/null @@ -1,25 +0,0 @@ -"""The brief-planning graph: propose a reviewable spec from defaults.""" - -from __future__ import annotations - -from langgraph.graph import StateGraph - -from .config import BriefConfig -from .nodes import propose_spec -from .state import BriefState - - -def build_brief_graph(): - workflow = StateGraph(BriefState, config_schema=BriefConfig) - - workflow.add_node("propose_spec", propose_spec) - - workflow.add_edge("__start__", "propose_spec") - workflow.add_edge("propose_spec", "__end__") - - graph = workflow.compile() - graph.name = "Surfsense Podcast Brief" - return graph - - -graph = build_brief_graph() diff --git a/surfsense_backend/app/podcasts/generation/brief/nodes.py b/surfsense_backend/app/podcasts/generation/brief/nodes.py deleted file mode 100644 index c0a6f1ae1..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/nodes.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Brief-planning node: propose a full spec from deterministic defaults. - -``propose_spec`` is pure resolution — it never spends tokens. It reuses the -user's last-used language/voices when available and otherwise falls back to -English, so the brief gate opens pre-filled and the common case needs no edits. -""" - -from __future__ import annotations - -from typing import Any - -from langchain_core.runnables import RunnableConfig - -from app.config import config as app_config -from app.podcasts.resolution import ( - DEFAULT_LANGUAGE, - LanguageContext, - resolve_language, - resolve_voices, -) -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - normalize_language_tag, -) -from app.podcasts.voices import ( - TtsProvider, - VoiceCatalog, - get_voice_catalog, - provider_from_service, -) - -from .config import BriefConfig -from .state import BriefState - -# Default role per speaker slot; extra speakers beyond the list fall back to guest. -_ROLE_BY_SLOT = ( - SpeakerRole.HOST, - SpeakerRole.GUEST, - SpeakerRole.EXPERT, - SpeakerRole.COHOST, - SpeakerRole.NARRATOR, -) - - -def propose_spec(state: BriefState, config: RunnableConfig) -> dict[str, Any]: - """Build a complete :class:`PodcastSpec` from the resolved defaults.""" - brief = BriefConfig.from_runnable_config(config) - provider = _active_provider() - catalog = get_voice_catalog() - - language = _supported_language( - last_used=brief.last_used_language, - provider=provider, - catalog=catalog, - ) - voices = resolve_voices( - catalog=catalog, - provider=provider, - language=language, - speaker_count=brief.speaker_count, - preferred=brief.last_used_voices, - ) - - speakers = [ - SpeakerSpec( - slot=slot, - name=_default_name(slot), - role=_role_for(slot), - voice_id=voice.voice_id, - ) - for slot, voice in enumerate(voices) - ] - spec = PodcastSpec( - language=language, - style=PodcastStyle.CONVERSATIONAL, - speakers=speakers, - duration=DurationTarget( - min_minutes=brief.min_minutes, max_minutes=brief.max_minutes - ), - focus=brief.focus, - ) - return {"spec": spec} - - -def _active_provider() -> TtsProvider: - service = app_config.TTS_SERVICE - if not service: - raise ValueError("TTS_SERVICE is not configured") - return provider_from_service(service) - - -def _supported_language( - *, - last_used: str | None, - provider: TtsProvider, - catalog: VoiceCatalog, -) -> str: - raw = resolve_language(LanguageContext(last_used=last_used)) - try: - language = normalize_language_tag(raw) - except ValueError: - language = DEFAULT_LANGUAGE - if not catalog.supports_language(provider, language): - return DEFAULT_LANGUAGE - return language - - -def _role_for(slot: int) -> SpeakerRole: - return _ROLE_BY_SLOT[slot] if slot < len(_ROLE_BY_SLOT) else SpeakerRole.GUEST - - -def _default_name(slot: int) -> str: - role = _role_for(slot) - label = role.value.replace("cohost", "co-host").title() - return label if slot < len(_ROLE_BY_SLOT) else f"{label} {slot}" diff --git a/surfsense_backend/app/podcasts/generation/brief/propose.py b/surfsense_backend/app/podcasts/generation/brief/propose.py deleted file mode 100644 index 17344702b..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/propose.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Propose a podcast's initial brief spec.""" - -from __future__ import annotations - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.podcasts.persistence import PodcastRepository -from app.podcasts.schemas import PodcastSpec -from app.podcasts.service import preferences_from - -from .config import DEFAULT_MAX_MINUTES, DEFAULT_MIN_MINUTES, DEFAULT_SPEAKER_COUNT -from .graph import graph as brief_graph -from .state import BriefState - - -async def propose_brief( - session: AsyncSession, - *, - search_space_id: int, - speaker_count: int = DEFAULT_SPEAKER_COUNT, - min_minutes: int = DEFAULT_MIN_MINUTES, - max_minutes: int = DEFAULT_MAX_MINUTES, - focus: str | None = None, -) -> PodcastSpec: - """Reuse the last-used language and voices, else English; return the spec.""" - last_language, last_voices = preferences_from( - await PodcastRepository(session).latest_with_spec(search_space_id) - ) - config = { - "configurable": { - "speaker_count": speaker_count, - "min_minutes": min_minutes, - "max_minutes": max_minutes, - "focus": focus, - "last_used_language": last_language, - "last_used_voices": last_voices, - } - } - result = await brief_graph.ainvoke(BriefState(), config=config) - return result["spec"] diff --git a/surfsense_backend/app/podcasts/generation/brief/state.py b/surfsense_backend/app/podcasts/generation/brief/state.py deleted file mode 100644 index 418fb6fa9..000000000 --- a/surfsense_backend/app/podcasts/generation/brief/state.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Mutable state threaded through the brief-planning graph.""" - -from __future__ import annotations - -from dataclasses import dataclass - -from app.podcasts.schemas import PodcastSpec - - -@dataclass -class BriefState: - """The proposed spec the graph produces; inputs arrive via the config.""" - - spec: PodcastSpec | None = None diff --git a/surfsense_backend/app/podcasts/generation/prompts/__init__.py b/surfsense_backend/app/podcasts/generation/prompts/__init__.py deleted file mode 100644 index 041dd4e6d..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Prompt builders for the generation graphs.""" - -from __future__ import annotations - -from .draft_segment import draft_segment_prompt -from .plan_outline import plan_outline_prompt -from .speakers import render_speaker_roster - -__all__ = [ - "draft_segment_prompt", - "plan_outline_prompt", - "render_speaker_roster", -] diff --git a/surfsense_backend/app/podcasts/generation/prompts/draft_segment.py b/surfsense_backend/app/podcasts/generation/prompts/draft_segment.py deleted file mode 100644 index c81dfa385..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/draft_segment.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Prompt for drafting one outline segment into dialogue turns. - -Each segment is drafted on its own so long episodes stay coherent and within -context limits. A short recap of the preceding dialogue is passed in so the new -segment continues naturally instead of restarting. The model must write in the -episode language and attribute every line to a real speaker slot. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from app.podcasts.schemas import PodcastSpec - -from .speakers import render_speaker_roster - -if TYPE_CHECKING: - from app.podcasts.generation.transcript.planning import OutlineSegment - - -def draft_segment_prompt( - *, - spec: PodcastSpec, - segment: OutlineSegment, - position: int, - total: int, - recap: str | None, -) -> str: - talking_points = "\n".join(f"- {point}" for point in segment.talking_points) - recap_block = ( - f"\nRecap of the conversation so far (continue from here, do not repeat " - f"it):\n{recap}\n" - if recap - else "\nThis is the opening segment; begin the conversation naturally.\n" - ) - return f"""\ -You are scripting natural, engaging podcast dialogue for segment {position} of \ -{total}. - -Write entirely in {spec.language}. The format is {spec.style.value}. -Speakers — attribute every line using these exact slot numbers: -{render_speaker_roster(spec)} -{recap_block} -This segment is "{segment.title}". Cover these points using only facts grounded \ -in the provided source content: -{talking_points} - -Aim for about {segment.target_words} words of dialogue. Keep turns conversational \ -and varied; speakers should react to each other rather than deliver monologues. \ -Do not add greetings or sign-offs unless this is the first or last segment. - -Respond with strict JSON and nothing else: -{{"turns": [{{"speaker": <slot>, "text": "..."}}]}} -""" diff --git a/surfsense_backend/app/podcasts/generation/prompts/plan_outline.py b/surfsense_backend/app/podcasts/generation/prompts/plan_outline.py deleted file mode 100644 index 1b227c2ff..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/plan_outline.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Prompt for planning a long-form podcast outline before drafting dialogue. - -Outlining first is what makes long-form reliable: a single LLM call cannot hold -a coherent one- to two-hour script, but it can plan segments that are then -drafted independently against a shared plan. The prompt is told the target -length so the number and size of segments scale with the requested duration. -""" - -from __future__ import annotations - -from app.podcasts.schemas import PodcastSpec - -from .speakers import render_speaker_roster - - -def plan_outline_prompt( - *, - spec: PodcastSpec, - target_words: int, - suggested_segments: int, - focus: str | None, -) -> str: - focus_block = ( - f"\nThe user asked the episode to focus on:\n{focus}\n" if focus else "" - ) - return f"""\ -You are a podcast showrunner planning the structure of an episode before any \ -dialogue is written. - -The episode language is {spec.language}. The format is {spec.style.value}. -Speakers (refer to them by these slots later): -{render_speaker_roster(spec)} -{focus_block} -Plan an outline that, when fully drafted, reaches roughly {target_words} words \ -of spoken dialogue (about {suggested_segments} segments). Each segment is one \ -coherent beat of the conversation: an opening, distinct topic areas grounded in \ -the source content, and a closing. - -For each segment provide: -- title: a short label for the beat -- talking_points: 2-5 concrete points to cover, drawn from the source content -- target_words: how many words of dialogue this segment should run (the sum \ -across segments should approximate {target_words}) - -Respond with strict JSON and nothing else: -{{"segments": [{{"title": "...", "talking_points": ["..."], "target_words": 0}}]}} -""" diff --git a/surfsense_backend/app/podcasts/generation/prompts/speakers.py b/surfsense_backend/app/podcasts/generation/prompts/speakers.py deleted file mode 100644 index 9df4138df..000000000 --- a/surfsense_backend/app/podcasts/generation/prompts/speakers.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Render a spec's speaker roster for prompts. - -The drafting prompts must reference speakers by the exact ``slot`` the renderer -expects, so this is the single place that formats that roster — keeping the -slot contract identical across every prompt that mentions speakers. -""" - -from __future__ import annotations - -from app.podcasts.schemas import PodcastSpec - - -def render_speaker_roster(spec: PodcastSpec) -> str: - lines = [ - f"- slot {speaker.slot} — {speaker.name} (role: {speaker.role.value})" - for speaker in spec.speakers - ] - return "\n".join(lines) diff --git a/surfsense_backend/app/podcasts/generation/structured.py b/surfsense_backend/app/podcasts/generation/structured.py deleted file mode 100644 index 08132e776..000000000 --- a/surfsense_backend/app/podcasts/generation/structured.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Parse a model's reply into a Pydantic shape, tolerating chatty output. - -Agent LLMs return JSON wrapped in prose, markdown fences, or reasoning blocks, -so a plain ``model_validate_json`` is unreliable. Centralising the tolerant -parse here keeps every generation node validating replies the same way. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, TypeVar - -from pydantic import BaseModel, ValidationError - -from app.utils.content_utils import extract_text_content, strip_markdown_fences - -if TYPE_CHECKING: - from langchain_core.messages import BaseMessage - -T = TypeVar("T", bound=BaseModel) - - -class StructuredOutputError(RuntimeError): - """The model reply could not be parsed into the expected shape.""" - - -async def invoke_json[T: BaseModel]( - llm, messages: list[BaseMessage], model: type[T] -) -> T: - """Invoke ``llm`` and validate its reply as ``model``.""" - response = await llm.ainvoke(messages) - content = strip_markdown_fences(extract_text_content(response.content)) - - try: - return model.model_validate_json(content) - except (ValidationError, ValueError): - pass - - start = content.find("{") - end = content.rfind("}") + 1 - if 0 <= start < end: - try: - return model.model_validate_json(content[start:end]) - except (ValidationError, ValueError) as exc: - raise StructuredOutputError( - f"could not parse {model.__name__} from model reply" - ) from exc - - raise StructuredOutputError( - f"no JSON object found for {model.__name__} in model reply" - ) diff --git a/surfsense_backend/app/podcasts/generation/transcript/__init__.py b/surfsense_backend/app/podcasts/generation/transcript/__init__.py deleted file mode 100644 index 5c8f23cd7..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Transcript drafting: outline-first, long-form dialogue generation.""" - -from __future__ import annotations - -from .config import TranscriptConfig -from .graph import build_transcript_graph -from .planning import Outline, OutlineSegment, SegmentDraft -from .state import TranscriptState - -__all__ = [ - "Outline", - "OutlineSegment", - "SegmentDraft", - "TranscriptConfig", - "TranscriptState", - "build_transcript_graph", -] diff --git a/surfsense_backend/app/podcasts/generation/transcript/config.py b/surfsense_backend/app/podcasts/generation/transcript/config.py deleted file mode 100644 index f627fc166..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/config.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Configurable inputs for the transcript-drafting graph.""" - -from __future__ import annotations - -from dataclasses import dataclass, fields - -from langchain_core.runnables import RunnableConfig - -from app.podcasts.schemas import PodcastSpec - - -@dataclass(kw_only=True) -class TranscriptConfig: - """The approved spec and user focus that drive drafting.""" - - search_space_id: int - spec: PodcastSpec - focus: str | None = None - - @classmethod - def from_runnable_config( - cls, config: RunnableConfig | None = None - ) -> TranscriptConfig: - configurable = (config.get("configurable") or {}) if config else {} - names = {f.name for f in fields(cls) if f.init} - return cls(**{k: v for k, v in configurable.items() if k in names}) diff --git a/surfsense_backend/app/podcasts/generation/transcript/graph.py b/surfsense_backend/app/podcasts/generation/transcript/graph.py deleted file mode 100644 index 2f97db50f..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/graph.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The transcript-drafting graph: outline, draft segments, finalize.""" - -from __future__ import annotations - -from langgraph.graph import StateGraph - -from .config import TranscriptConfig -from .nodes import draft_segments, finalize, plan_outline -from .state import TranscriptState - - -def build_transcript_graph(): - workflow = StateGraph(TranscriptState, config_schema=TranscriptConfig) - - workflow.add_node("plan_outline", plan_outline) - workflow.add_node("draft_segments", draft_segments) - workflow.add_node("finalize", finalize) - - workflow.add_edge("__start__", "plan_outline") - workflow.add_edge("plan_outline", "draft_segments") - workflow.add_edge("draft_segments", "finalize") - workflow.add_edge("finalize", "__end__") - - graph = workflow.compile() - graph.name = "Surfsense Podcast Transcript" - return graph - - -graph = build_transcript_graph() diff --git a/surfsense_backend/app/podcasts/generation/transcript/nodes.py b/surfsense_backend/app/podcasts/generation/transcript/nodes.py deleted file mode 100644 index 44d6b219d..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/nodes.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Transcript-drafting nodes: plan an outline, draft each beat, then assemble. - -Long-form is produced beat-by-beat: a single call plans the structure, then each -segment is drafted on its own with a recap of what came before so the script -stays coherent without holding the whole episode in one context window. -""" - -from __future__ import annotations - -from typing import Any - -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.runnables import RunnableConfig - -from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn -from app.services.llm_service import get_agent_llm - -from ..prompts import draft_segment_prompt, plan_outline_prompt -from ..structured import invoke_json -from .config import TranscriptConfig -from .planning import Outline, SegmentDraft -from .state import TranscriptState - -# Average speaking rate; converts target minutes to a target word count. -_WORDS_PER_MINUTE = 150 -# Rough words per outline segment, used to suggest how many segments to plan. -_WORDS_PER_SEGMENT = 250 -# Cap on source text sent per LLM call to bound tokens on large sources. -_SOURCE_BUDGET_CHARS = 12000 -# How much prior dialogue to recap into each segment for continuity. -_RECAP_CHARS = 800 - - -async def plan_outline( - state: TranscriptState, config: RunnableConfig -) -> dict[str, Any]: - """Plan the segment structure sized to the spec's target duration.""" - tc = TranscriptConfig.from_runnable_config(config) - llm = await _require_llm(state, tc) - - target_words = round(tc.spec.duration.midpoint_minutes * _WORDS_PER_MINUTE) - suggested_segments = max(1, round(target_words / _WORDS_PER_SEGMENT)) - - messages = [ - SystemMessage( - content=plan_outline_prompt( - spec=tc.spec, - target_words=target_words, - suggested_segments=suggested_segments, - focus=tc.focus, - ) - ), - HumanMessage(content=_source_block(state.source_content)), - ] - outline = await invoke_json(llm, messages, Outline) - return {"outline": outline} - - -async def draft_segments( - state: TranscriptState, config: RunnableConfig -) -> dict[str, Any]: - """Draft each outline segment in order, carrying a running recap.""" - tc = TranscriptConfig.from_runnable_config(config) - llm = await _require_llm(state, tc) - outline = state.outline - if outline is None: - raise RuntimeError("draft_segments requires an outline") - - source_block = _source_block(state.source_content) - turns: list[TranscriptTurn] = [] - total = len(outline.segments) - - for index, segment in enumerate(outline.segments): - messages = [ - SystemMessage( - content=draft_segment_prompt( - spec=tc.spec, - segment=segment, - position=index + 1, - total=total, - recap=_recap(turns, tc.spec), - ) - ), - HumanMessage(content=source_block), - ] - draft = await invoke_json(llm, messages, SegmentDraft) - turns.extend(_valid_turns(draft, tc.spec)) - - return {"drafted_turns": turns} - - -def finalize(state: TranscriptState, config: RunnableConfig) -> dict[str, Any]: - """Assemble drafted turns into a validated transcript.""" - if not state.drafted_turns: - raise RuntimeError("drafting produced no usable dialogue") - return {"transcript": Transcript(turns=state.drafted_turns)} - - -async def _require_llm(state: TranscriptState, tc: TranscriptConfig): - llm = await get_agent_llm(state.db_session, tc.search_space_id) - if llm is None: - raise RuntimeError( - f"no agent LLM configured for search space {tc.search_space_id}" - ) - return llm - - -def _source_block(source_content: str) -> str: - sample = (source_content or "")[:_SOURCE_BUDGET_CHARS] - return f"<source_content>{sample}</source_content>" - - -def _valid_turns(draft: SegmentDraft, spec: PodcastSpec) -> list[TranscriptTurn]: - # Drop any turn the model attributed to a slot the spec doesn't define, so a - # stray attribution can't break rendering downstream. - valid_slots = {speaker.slot for speaker in spec.speakers} - return [turn for turn in draft.turns if turn.speaker in valid_slots] - - -def _recap(turns: list[TranscriptTurn], spec: PodcastSpec) -> str | None: - if not turns: - return None - names = {speaker.slot: speaker.name for speaker in spec.speakers} - rendered = "\n".join( - f"{names.get(turn.speaker, turn.speaker)}: {turn.text}" for turn in turns - ) - return rendered[-_RECAP_CHARS:] diff --git a/surfsense_backend/app/podcasts/generation/transcript/planning.py b/surfsense_backend/app/podcasts/generation/transcript/planning.py deleted file mode 100644 index 3f6aeac9b..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/planning.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Internal shapes the transcript graph passes between its nodes. - -These are generation-time artifacts (the outline and per-segment drafts), not -persisted or API-facing. Segment drafts reuse :class:`TranscriptTurn` so the -speaker-slot contract and turn validation are identical to the final transcript. -""" - -from __future__ import annotations - -from pydantic import BaseModel, Field - -from app.podcasts.schemas import TranscriptTurn - - -class OutlineSegment(BaseModel): - """One planned beat of the conversation, drafted independently.""" - - title: str = Field(..., min_length=1) - talking_points: list[str] = Field(default_factory=list) - target_words: int = Field(..., ge=1) - - -class Outline(BaseModel): - """The full plan: ordered segments sized to the target duration.""" - - segments: list[OutlineSegment] = Field(..., min_length=1) - - -class SegmentDraft(BaseModel): - """The dialogue a single segment produced.""" - - turns: list[TranscriptTurn] = Field(default_factory=list) diff --git a/surfsense_backend/app/podcasts/generation/transcript/state.py b/surfsense_backend/app/podcasts/generation/transcript/state.py deleted file mode 100644 index f11337471..000000000 --- a/surfsense_backend/app/podcasts/generation/transcript/state.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Mutable state threaded through the transcript-drafting graph.""" - -from __future__ import annotations - -from dataclasses import dataclass, field - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.podcasts.schemas import Transcript, TranscriptTurn - -from .planning import Outline - - -@dataclass -class TranscriptState: - """Source content plus the intermediate and final drafting artifacts.""" - - db_session: AsyncSession - source_content: str - outline: Outline | None = None - drafted_turns: list[TranscriptTurn] = field(default_factory=list) - transcript: Transcript | None = None diff --git a/surfsense_backend/app/podcasts/persistence/__init__.py b/surfsense_backend/app/podcasts/persistence/__init__.py deleted file mode 100644 index 2166d5d9d..000000000 --- a/surfsense_backend/app/podcasts/persistence/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Models, enums, and data access for the podcasts table.""" - -from __future__ import annotations - -from .enums import PodcastStatus -from .models import Podcast -from .repository import PodcastRepository - -__all__ = ["Podcast", "PodcastRepository", "PodcastStatus"] diff --git a/surfsense_backend/app/podcasts/persistence/enums/__init__.py b/surfsense_backend/app/podcasts/persistence/enums/__init__.py deleted file mode 100644 index f0527fd78..000000000 --- a/surfsense_backend/app/podcasts/persistence/enums/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Enums for the podcasts table.""" - -from __future__ import annotations - -from .podcast_status import PodcastStatus - -__all__ = ["PodcastStatus"] diff --git a/surfsense_backend/app/podcasts/persistence/enums/podcast_status.py b/surfsense_backend/app/podcasts/persistence/enums/podcast_status.py deleted file mode 100644 index 28f29afb5..000000000 --- a/surfsense_backend/app/podcasts/persistence/enums/podcast_status.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Podcast generation lifecycle. - -The status drives a guarded state machine. A podcast is proposed (``PENDING``), -gets a reviewable brief (``AWAITING_BRIEF``), is drafted into a transcript -(``DRAFTING``), then rendered to audio (``RENDERING`` → ``READY``). ``FAILED`` -and ``CANCELLED`` are terminal; a ``READY`` episode can be sent back to the -brief gate for regeneration, and an in-flight regeneration can be reverted to -``READY`` while the previous audio still exists. ``AWAITING_REVIEW`` is -retained for legacy rows but -never entered anymore — the brief is the only approval gate. The Python enum is -kept in lockstep with the ``podcast_status`` Postgres type via its paired -migration. -""" - -from __future__ import annotations - -from enum import StrEnum - - -class PodcastStatus(StrEnum): - PENDING = "pending" - AWAITING_BRIEF = "awaiting_brief" - DRAFTING = "drafting" - AWAITING_REVIEW = "awaiting_review" - RENDERING = "rendering" - READY = "ready" - FAILED = "failed" - CANCELLED = "cancelled" - - @property - def is_terminal(self) -> bool: - """Whether no further transition is possible from this state.""" - return self in _TERMINAL - - @property - def is_gate(self) -> bool: - """Whether this state waits on user input before proceeding.""" - return self in _GATES - - -_TERMINAL = frozenset({PodcastStatus.FAILED, PodcastStatus.CANCELLED}) -_GATES = frozenset({PodcastStatus.AWAITING_BRIEF, PodcastStatus.AWAITING_REVIEW}) diff --git a/surfsense_backend/app/podcasts/persistence/models.py b/surfsense_backend/app/podcasts/persistence/models.py deleted file mode 100644 index 6e40a8040..000000000 --- a/surfsense_backend/app/podcasts/persistence/models.py +++ /dev/null @@ -1,82 +0,0 @@ -"""``podcasts`` table: a generated podcast, its brief, transcript, and state.""" - -from __future__ import annotations - -from sqlalchemy import ( - Column, - Enum as SQLAlchemyEnum, - ForeignKey, - Integer, - String, - Text, -) -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import relationship - -from app.db import BaseModel, TimestampMixin - -from .enums import PodcastStatus - - -class Podcast(BaseModel, TimestampMixin): - """A podcast across its whole lifecycle: brief, transcript, audio, status. - - ``spec`` (the reviewable brief) and ``podcast_transcript`` are JSONB so the - flexible Pydantic shapes can evolve without migrations. ``spec_version`` - backs optimistic concurrency on brief edits. Rendered audio lives in the - object store, addressed by ``storage_backend`` + ``storage_key`` rather than - a raw path. - """ - - __tablename__ = "podcasts" - - title = Column(String(500), nullable=False) - - status = Column( - SQLAlchemyEnum( - PodcastStatus, - name="podcast_status", - create_type=False, - values_callable=lambda x: [e.value for e in x], - ), - nullable=False, - default=PodcastStatus.PENDING, - server_default=PodcastStatus.PENDING.value, - index=True, - ) - - # The source material the episode is generated from. Persisted because - # drafting happens after the brief gate, long after creation. - source_content = Column(Text, nullable=True) - - # The reviewable brief (PodcastSpec); null until the brief gate is reached. - spec = Column(JSONB, nullable=True) - # Bumped on every spec edit; guards concurrent edits at the brief gate. - spec_version = Column(Integer, nullable=False, default=1, server_default="1") - - # The drafted dialogue (Transcript); null until drafting completes. - podcast_transcript = Column(JSONB, nullable=True) - - # Where the rendered audio lives in the object store; null until READY. - storage_backend = Column(String(32), nullable=True) - storage_key = Column(Text, nullable=True) - duration_seconds = Column(Integer, nullable=True) - - # Human-readable reason when status is FAILED. - error = Column(Text, nullable=True) - - # Legacy local audio path; retained for back-compat until cutover. - file_location = Column(Text, nullable=True) - - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="podcasts") - - thread_id = Column( - Integer, - ForeignKey("new_chat_threads.id", ondelete="SET NULL"), - nullable=True, - index=True, - ) - thread = relationship("NewChatThread") diff --git a/surfsense_backend/app/podcasts/persistence/repository.py b/surfsense_backend/app/podcasts/persistence/repository.py deleted file mode 100644 index 04eae9ce1..000000000 --- a/surfsense_backend/app/podcasts/persistence/repository.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Data access for the ``podcasts`` table. - -A thin async repository so the service and tasks never write raw queries. It -only loads and persists rows; lifecycle rules and (de)serialization live in the -service. -""" - -from __future__ import annotations - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from .models import Podcast - - -class PodcastRepository: - """Loads and stores :class:`Podcast` rows for one session.""" - - def __init__(self, session: AsyncSession) -> None: - self._session = session - - async def get(self, podcast_id: int) -> Podcast | None: - return await self._session.get(Podcast, podcast_id) - - async def add(self, podcast: Podcast) -> Podcast: - """Persist a new row and assign its primary key.""" - self._session.add(podcast) - await self._session.flush() - return podcast - - async def latest_with_spec(self, search_space_id: int) -> Podcast | None: - """Most recent podcast in the space that has a stored brief. - - Used to seed language/voice defaults for a new podcast from what the - user chose last. - """ - result = await self._session.execute( - select(Podcast) - .where( - Podcast.search_space_id == search_space_id, - Podcast.spec.is_not(None), - ) - .order_by(Podcast.created_at.desc()) - .limit(1) - ) - return result.scalars().first() diff --git a/surfsense_backend/app/podcasts/rendering/__init__.py b/surfsense_backend/app/podcasts/rendering/__init__.py deleted file mode 100644 index 9fb50a2e1..000000000 --- a/surfsense_backend/app/podcasts/rendering/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Rendering: synthesise and merge an approved transcript into audio. - -The :class:`PodcastRenderer` is the public entry point; the segment cache and -FFmpeg merge are implementation details it owns. -""" - -from __future__ import annotations - -from .errors import RenderError -from .renderer import PodcastRenderer, RenderedPodcast - -__all__ = ["PodcastRenderer", "RenderError", "RenderedPodcast"] diff --git a/surfsense_backend/app/podcasts/rendering/cache.py b/surfsense_backend/app/podcasts/rendering/cache.py deleted file mode 100644 index 32d9f0c21..000000000 --- a/surfsense_backend/app/podcasts/rendering/cache.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Content-addressed cache for synthesised segments. - -Each segment's audio is keyed by everything that determines its bytes (voice, -language, speed, text). Keeping the cache in a stable per-podcast directory -makes re-renders cheap: changing one speaker's voice only misses that speaker's -turns, and a worker restart mid-render resumes from whatever was already -written. The key intentionally excludes the segment's position so identical -lines (e.g. repeated "Right.") synthesise once. -""" - -from __future__ import annotations - -import hashlib -import json -from pathlib import Path - -from app.podcasts.tts import SynthesisRequest - - -class SegmentCache: - """On-disk store of segment audio, addressed by request content hash.""" - - def __init__(self, root: Path) -> None: - self._root = root - self._root.mkdir(parents=True, exist_ok=True) - - def key(self, request: SynthesisRequest) -> str: - """A stable hash of the inputs that determine the synthesised bytes.""" - material = json.dumps( - { - "voice": request.voice, - "language": request.language, - "speed": request.speed, - "text": request.text, - }, - sort_keys=True, - ensure_ascii=True, - ) - return hashlib.sha256(material.encode("utf-8")).hexdigest() - - def path(self, key: str, container: str) -> Path: - return self._root / f"{key}.{container}" - - def get(self, key: str, container: str) -> Path | None: - """Return the cached segment path, or ``None`` on a miss.""" - path = self.path(key, container) - return path if path.exists() else None - - def put(self, key: str, container: str, data: bytes) -> Path: - """Write ``data`` for ``key`` and return its path.""" - path = self.path(key, container) - path.write_bytes(data) - return path diff --git a/surfsense_backend/app/podcasts/rendering/errors.py b/surfsense_backend/app/podcasts/rendering/errors.py deleted file mode 100644 index 7192890c6..000000000 --- a/surfsense_backend/app/podcasts/rendering/errors.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Failures raised while rendering a transcript to audio.""" - -from __future__ import annotations - - -class RenderError(RuntimeError): - """Rendering could not produce a final audio file. - - Wraps both per-segment synthesis failures and the merge step so the render - task sees one failure type regardless of where it originated. - """ diff --git a/surfsense_backend/app/podcasts/rendering/merge.py b/surfsense_backend/app/podcasts/rendering/merge.py deleted file mode 100644 index 223295349..000000000 --- a/surfsense_backend/app/podcasts/rendering/merge.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Concatenate ordered segment files into a single MP3. - -Uses FFmpeg's concat *demuxer* (a list file of inputs) rather than a -``filter_complex`` graph. The demuxer takes one ``-i`` no matter how many -segments there are, so an hour-long episode with thousands of turns never hits -command-line length limits. Output is always re-encoded to MP3 for a uniform -artifact regardless of the source container (Kokoro WAV or hosted MP3). -""" - -from __future__ import annotations - -from pathlib import Path - -from ffmpeg.asyncio import FFmpeg - -from .errors import RenderError - - -async def concat_to_mp3(segment_paths: list[Path], output_path: Path) -> None: - """Merge ``segment_paths`` in order into ``output_path`` as MP3.""" - if not segment_paths: - raise RenderError("cannot merge an empty list of segments") - - list_file = output_path.with_name(f"{output_path.stem}.concat.txt") - list_file.write_text(_concat_list(segment_paths), encoding="utf-8") - - try: - ffmpeg = ( - FFmpeg() - .option("y") - .input(str(list_file), f="concat", safe=0) - .output(str(output_path), {"c:a": "libmp3lame"}) - ) - await ffmpeg.execute() - except Exception as exc: - raise RenderError(f"audio merge failed: {exc}") from exc - finally: - list_file.unlink(missing_ok=True) - - -def _concat_list(segment_paths: list[Path]) -> str: - # The concat demuxer reads `file '<path>'` lines; single quotes in a path - # are escaped per its quoting rules ('\''). - lines = [] - for path in segment_paths: - escaped = str(path.resolve()).replace("'", "'\\''") - lines.append(f"file '{escaped}'") - return "\n".join(lines) + "\n" diff --git a/surfsense_backend/app/podcasts/rendering/renderer.py b/surfsense_backend/app/podcasts/rendering/renderer.py deleted file mode 100644 index 44071c060..000000000 --- a/surfsense_backend/app/podcasts/rendering/renderer.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Render an approved transcript into a single podcast audio file. - -The renderer is the only place that turns dialogue into sound. It maps each -turn to its speaker's voice, synthesises segments concurrently (capped, served -from the segment cache when possible, and coalesced so identical lines render -once), then merges them in order. It takes a settled spec + transcript and -returns bytes; persistence and lifecycle transitions belong to the service. -""" - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass -from pathlib import Path - -from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn -from app.podcasts.tts import SynthesisRequest, TextToSpeech, TextToSpeechError -from app.podcasts.voices import VoiceCatalog - -from .cache import SegmentCache -from .errors import RenderError -from .merge import concat_to_mp3 - -# Bounds how many segments synthesise at once. Protects hosted-provider rate -# limits and avoids thrashing the local Kokoro pipeline; the renderer is I/O- or -# model-bound per segment, so a small pool already saturates throughput. -DEFAULT_MAX_CONCURRENCY = 4 - -_MERGED_FILENAME = "podcast.mp3" - - -@dataclass(frozen=True, slots=True) -class RenderedPodcast: - """The finished episode: encoded bytes plus their container.""" - - data: bytes - container: str - - -class PodcastRenderer: - """Synthesises and merges a transcript using one TTS provider.""" - - def __init__( - self, - *, - tts: TextToSpeech, - catalog: VoiceCatalog, - max_concurrency: int = DEFAULT_MAX_CONCURRENCY, - ) -> None: - self._tts = tts - self._catalog = catalog - self._max_concurrency = max_concurrency - - async def render( - self, - *, - spec: PodcastSpec, - transcript: Transcript, - workdir: Path, - ) -> RenderedPodcast: - """Produce the merged MP3 for ``transcript`` under ``spec``. - - ``workdir`` holds the segment cache and merge output; reusing the same - directory across renders is what makes voice edits cheap. - """ - cache = SegmentCache(workdir / "segments") - requests = [self._request_for(spec, turn) for turn in transcript.turns] - - # Concurrency primitives are created per render so each call is bound to - # the event loop running it (Celery tasks may use a fresh loop). - synthesizer = _SegmentSynthesizer(self._tts, cache, self._max_concurrency) - segment_paths = await asyncio.gather( - *(synthesizer.segment(request) for request in requests) - ) - - output_path = workdir / _MERGED_FILENAME - await concat_to_mp3(list(segment_paths), output_path) - return RenderedPodcast(data=output_path.read_bytes(), container="mp3") - - def _request_for(self, spec: PodcastSpec, turn: TranscriptTurn) -> SynthesisRequest: - try: - speaker = spec.speaker_for(turn.speaker) - except KeyError as exc: - raise RenderError( - f"transcript references unknown speaker slot {turn.speaker}" - ) from exc - try: - voice = self._catalog.get(speaker.voice_id) - except KeyError as exc: - raise RenderError(f"unknown voice {speaker.voice_id!r}") from exc - return SynthesisRequest( - text=turn.text, voice=voice.native_ref, language=spec.language - ) - - -class _SegmentSynthesizer: - """Per-render synthesis coordinator: caps concurrency and dedupes work. - - Beyond the on-disk cache (which serves cross-render reuse), this coalesces - identical segments that race within one render so the same line is voiced - once even when several turns request it simultaneously. - """ - - def __init__( - self, tts: TextToSpeech, cache: SegmentCache, max_concurrency: int - ) -> None: - self._tts = tts - self._cache = cache - self._container = tts.container - self._semaphore = asyncio.Semaphore(max_concurrency) - self._inflight: dict[str, asyncio.Future[Path]] = {} - self._inflight_lock = asyncio.Lock() - - async def segment(self, request: SynthesisRequest) -> Path: - key = self._cache.key(request) - cached = self._cache.get(key, self._container) - if cached is not None: - return cached - - async with self._inflight_lock: - future = self._inflight.get(key) - owner = future is None - if owner: - future = asyncio.get_event_loop().create_future() - self._inflight[key] = future - - # The owner runs the work and publishes the outcome on the shared future; - # every caller (owner included) reads it back via ``await future`` so the - # result is retrieved exactly once-or-more and never left dangling. - if owner: - try: - path = await self._synthesize(request, key) - except BaseException as exc: - future.set_exception(exc) - else: - future.set_result(path) - finally: - await self._forget(key) - - return await future - - async def _synthesize(self, request: SynthesisRequest, key: str) -> Path: - async with self._semaphore: - cached = self._cache.get(key, self._container) - if cached is not None: - return cached - try: - audio = await self._tts.synthesize(request) - except TextToSpeechError as exc: - raise RenderError(f"segment synthesis failed: {exc}") from exc - return self._cache.put(key, audio.container, audio.data) - - async def _forget(self, key: str) -> None: - async with self._inflight_lock: - self._inflight.pop(key, None) diff --git a/surfsense_backend/app/podcasts/resolution/__init__.py b/surfsense_backend/app/podcasts/resolution/__init__.py deleted file mode 100644 index 19a7edfb3..000000000 --- a/surfsense_backend/app/podcasts/resolution/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Resolution: deterministic default chains for a fresh brief. - -Turns the user's last-used preferences into concrete language and voice -defaults, so the brief gate opens pre-filled and most users approve without -editing. -""" - -from __future__ import annotations - -from .language import ( - DEFAULT_LANGUAGE, - DEFAULT_LANGUAGE_CHAIN, - LanguageContext, - LanguageResolver, - resolve_language, -) -from .voices import VoiceResolutionError, resolve_voices - -__all__ = [ - "DEFAULT_LANGUAGE", - "DEFAULT_LANGUAGE_CHAIN", - "LanguageContext", - "LanguageResolver", - "VoiceResolutionError", - "resolve_language", - "resolve_voices", -] diff --git a/surfsense_backend/app/podcasts/resolution/language.py b/surfsense_backend/app/podcasts/resolution/language.py deleted file mode 100644 index 336d9036b..000000000 --- a/surfsense_backend/app/podcasts/resolution/language.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Resolve the brief's language without spending tokens at the gate. - -The chain mirrors the agreed policy: reuse the language the user last chose, and -otherwise default to English (which the user can still override in the brief). We -deliberately never guess the language from the source content — proposing a -language the user did not ask for is worse than a predictable default. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass - -# What a brand-new user with no signal gets, and what every chain ends on. -DEFAULT_LANGUAGE = "en" - - -@dataclass(frozen=True, slots=True) -class LanguageContext: - """Signals available when proposing a language for a fresh podcast.""" - - last_used: str | None = None - - -class LanguageResolver(ABC): - """One step in the language fallback chain.""" - - @abstractmethod - def resolve(self, context: LanguageContext) -> str | None: - """Return a language tag, or ``None`` to defer to the next resolver.""" - - -class LastUsedLanguage(LanguageResolver): - """Reuse the language from the user's previous podcast.""" - - def resolve(self, context: LanguageContext) -> str | None: - return context.last_used - - -class DefaultLanguage(LanguageResolver): - """Terminal step: always yields the default so the chain never fails.""" - - def resolve(self, context: LanguageContext) -> str | None: - return DEFAULT_LANGUAGE - - -# Order encodes the policy; prepend stronger signals here as they appear. -DEFAULT_LANGUAGE_CHAIN: tuple[LanguageResolver, ...] = ( - LastUsedLanguage(), - DefaultLanguage(), -) - - -def resolve_language( - context: LanguageContext, - chain: tuple[LanguageResolver, ...] = DEFAULT_LANGUAGE_CHAIN, -) -> str: - """Walk ``chain`` and return the first language a resolver yields.""" - for resolver in chain: - language = resolver.resolve(context) - if language: - return language.strip() - # The default resolver guarantees a value; this guards a misconfigured chain. - return DEFAULT_LANGUAGE diff --git a/surfsense_backend/app/podcasts/resolution/voices.py b/surfsense_backend/app/podcasts/resolution/voices.py deleted file mode 100644 index 8d865fbaa..000000000 --- a/surfsense_backend/app/podcasts/resolution/voices.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Assign a default voice to each speaker for the resolved language. - -The default chain reuses the user's previously chosen voices where they are -still valid for the new language/provider, then fills any remaining speakers -with distinct catalog voices (preferring an unused gender so a two-speaker -episode sounds like two people). The user can override any of these in the -brief; this only seeds sensible defaults so most briefs need no edits. -""" - -from __future__ import annotations - -from collections.abc import Sequence - -from app.podcasts.voices import CatalogVoice, TtsProvider, VoiceCatalog - - -class VoiceResolutionError(RuntimeError): - """No catalog voice exists for the requested provider and language.""" - - -def resolve_voices( - *, - catalog: VoiceCatalog, - provider: TtsProvider, - language: str, - speaker_count: int, - preferred: Sequence[str] | None = None, -) -> list[CatalogVoice]: - """Return one :class:`CatalogVoice` per speaker, in slot order. - - ``preferred`` is the user's last-used voice ids (by slot); any that no - longer fit the provider/language are silently dropped and replaced. - """ - if speaker_count < 1: - raise ValueError("speaker_count must be >= 1") - - available = catalog.for_language(provider, language) - if not available: - raise VoiceResolutionError( - f"{provider.value} has no voice for language {language!r}" - ) - - preferred = preferred or () - by_id = {voice.voice_id: voice for voice in available} - - assignment: list[CatalogVoice] = [] - used_ids: set[str] = set() - used_genders: set = set() - - for slot in range(speaker_count): - reuse_id = preferred[slot] if slot < len(preferred) else None - if reuse_id and reuse_id in by_id and reuse_id not in used_ids: - voice = by_id[reuse_id] - else: - voice = _pick_distinct(available, used_ids, used_genders) - assignment.append(voice) - used_ids.add(voice.voice_id) - used_genders.add(voice.gender) - - return assignment - - -def _pick_distinct( - available: list[CatalogVoice], - used_ids: set[str], - used_genders: set, -) -> CatalogVoice: - """Pick a fresh voice, preferring an unused gender, then any unused voice. - - Falls back to the first catalog voice when speakers outnumber distinct - voices, so resolution always assigns every speaker rather than failing. - """ - fresh = [v for v in available if v.voice_id not in used_ids] - if fresh: - for voice in fresh: - if voice.gender not in used_genders: - return voice - return fresh[0] - return available[0] diff --git a/surfsense_backend/app/podcasts/schemas/__init__.py b/surfsense_backend/app/podcasts/schemas/__init__.py deleted file mode 100644 index cd19a21cc..000000000 --- a/surfsense_backend/app/podcasts/schemas/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Pydantic shapes for the podcast brief and transcript.""" - -from __future__ import annotations - -from .spec import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - normalize_language_tag, -) -from .transcript import Transcript, TranscriptTurn - -__all__ = [ - "DurationTarget", - "PodcastSpec", - "PodcastStyle", - "SpeakerRole", - "SpeakerSpec", - "Transcript", - "TranscriptTurn", - "normalize_language_tag", -] diff --git a/surfsense_backend/app/podcasts/schemas/spec.py b/surfsense_backend/app/podcasts/schemas/spec.py deleted file mode 100644 index 1ef3dcfff..000000000 --- a/surfsense_backend/app/podcasts/schemas/spec.py +++ /dev/null @@ -1,166 +0,0 @@ -"""The brief: the editable configuration a user approves before drafting. - -A :class:`PodcastSpec` front-loads every decision that drives token or audio -cost (language, speakers, voices, style, target length) so the expensive -drafting and rendering steps run once against settled inputs. It is stored as -JSONB on the ``podcasts`` row and round-trips through the review API. -""" - -from __future__ import annotations - -import re -from enum import StrEnum - -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - -# A speaker count beyond this is almost never a real podcast and explodes the -# voice/turn-attribution space, so we reject it at the brief gate. -MAX_SPEAKERS = 6 - -# Long-form is a goal, but an open-ended upper bound invites runaway TTS bills. -# One day of audio is a generous ceiling that still blocks obvious mistakes. -MAX_DURATION_MINUTES = 24 * 60 - -# BCP-47 primary subtag plus optional region (e.g. ``en``, ``en-US``, ``pt-BR``). -# Kept deliberately permissive: the voice catalog, not the brief, decides which -# languages can actually be synthesised. Casing is normalised after matching. -_LANGUAGE_TAG = re.compile(r"^[A-Za-z]{2,3}(-[A-Za-z0-9]{2,8})*$") - - -def normalize_language_tag(value: str) -> str: - """Validate and canonicalise a BCP-47 tag (lowercased primary subtag). - - Shared with the generation layer so resolved and user-entered languages are - normalised identically before they reach a :class:`PodcastSpec`. - """ - cleaned = value.strip() - if not _LANGUAGE_TAG.match(cleaned): - raise ValueError(f"not a valid BCP-47 language tag: {value!r}") - primary, _, rest = cleaned.partition("-") - return primary.lower() if not rest else f"{primary.lower()}-{rest}" - - -class SpeakerRole(StrEnum): - """How a speaker functions in the conversation, used to steer drafting.""" - - HOST = "host" - COHOST = "cohost" - GUEST = "guest" - EXPERT = "expert" - NARRATOR = "narrator" - - -class PodcastStyle(StrEnum): - """The conversational format the transcript should follow.""" - - CONVERSATIONAL = "conversational" - INTERVIEW = "interview" - DEBATE = "debate" - MONOLOGUE = "monologue" - NARRATIVE = "narrative" - - -class SpeakerSpec(BaseModel): - """One voice in the podcast: who they are and which TTS voice renders them. - - ``slot`` is the stable join key. Transcript turns reference a speaker by - ``slot`` and the renderer resolves ``voice_id`` for that same slot, so the - two never drift even if speakers are reordered in the brief. - """ - - model_config = ConfigDict(extra="forbid") - - slot: int = Field( - ..., ge=0, description="Stable index a transcript turn references" - ) - name: str = Field(..., min_length=1, max_length=120) - role: SpeakerRole - voice_id: str = Field( - ..., - min_length=1, - description="Catalog voice id valid for the spec's language and provider", - ) - - @field_validator("name", "voice_id") - @classmethod - def _strip_required_text(cls, value: str) -> str: - cleaned = value.strip() - if not cleaned: - raise ValueError("must not be blank") - return cleaned - - -class DurationTarget(BaseModel): - """The desired finished length as an inclusive minute range. - - Drafting aims for the midpoint and treats the bounds as soft guardrails; - storing a range (rather than a point) keeps long-form expectations honest - without pretending we can hit an exact runtime. - """ - - model_config = ConfigDict(extra="forbid") - - min_minutes: int = Field(..., ge=1, le=MAX_DURATION_MINUTES) - max_minutes: int = Field(..., ge=1, le=MAX_DURATION_MINUTES) - - @model_validator(mode="after") - def _check_order(self) -> DurationTarget: - if self.max_minutes < self.min_minutes: - raise ValueError("max_minutes must be >= min_minutes") - return self - - @property - def midpoint_minutes(self) -> float: - """The runtime drafting should aim for within the range.""" - return (self.min_minutes + self.max_minutes) / 2 - - -class PodcastSpec(BaseModel): - """The full brief approved before any tokens or audio are spent.""" - - model_config = ConfigDict(extra="forbid") - - language: str = Field(..., description="BCP-47 tag, e.g. 'en', 'en-US', 'pt-BR'") - style: PodcastStyle = PodcastStyle.CONVERSATIONAL - speakers: list[SpeakerSpec] = Field(..., min_length=1, max_length=MAX_SPEAKERS) - duration: DurationTarget - focus: str | None = Field( - default=None, - max_length=2000, - description="Optional user steer for what the episode should emphasise", - ) - - @field_validator("language") - @classmethod - def _normalise_language(cls, value: str) -> str: - return normalize_language_tag(value) - - @field_validator("focus") - @classmethod - def _blank_focus_is_none(cls, value: str | None) -> str | None: - if value is None: - return None - cleaned = value.strip() - return cleaned or None - - @model_validator(mode="after") - def _check_speaker_slots(self) -> PodcastSpec: - slots = [speaker.slot for speaker in self.speakers] - if len(slots) != len(set(slots)): - raise ValueError("speaker slots must be unique") - return self - - @model_validator(mode="after") - def _check_style_speakers(self) -> PodcastSpec: - # One voice is what "monologue" means; letting extra speakers through - # would force drafting to silently pick a winner. - if self.style is PodcastStyle.MONOLOGUE and len(self.speakers) != 1: - raise ValueError("a monologue has exactly one speaker") - return self - - def speaker_for(self, slot: int) -> SpeakerSpec: - """Return the speaker bound to ``slot`` or raise if none matches.""" - for speaker in self.speakers: - if speaker.slot == slot: - return speaker - raise KeyError(f"no speaker for slot {slot}") diff --git a/surfsense_backend/app/podcasts/schemas/transcript.py b/surfsense_backend/app/podcasts/schemas/transcript.py deleted file mode 100644 index b4c1463d8..000000000 --- a/surfsense_backend/app/podcasts/schemas/transcript.py +++ /dev/null @@ -1,41 +0,0 @@ -"""The transcript: ordered dialogue turns drafting produces for review. - -A :class:`Transcript` is the reviewable artifact at the go/no-go gate and the -exact input the renderer turns into audio. Each turn names a speaker by the -``slot`` defined in the :class:`~app.podcasts.schemas.spec.PodcastSpec`, so the -renderer can resolve the right voice without re-attributing anything. -""" - -from __future__ import annotations - -from pydantic import BaseModel, ConfigDict, Field, field_validator - - -class TranscriptTurn(BaseModel): - """A single spoken line by one speaker.""" - - model_config = ConfigDict(extra="forbid") - - speaker: int = Field(..., ge=0, description="The PodcastSpec speaker slot speaking") - text: str = Field(..., min_length=1) - - @field_validator("text") - @classmethod - def _strip_text(cls, value: str) -> str: - cleaned = value.strip() - if not cleaned: - raise ValueError("turn text must not be blank") - return cleaned - - -class Transcript(BaseModel): - """The full ordered dialogue for an episode.""" - - model_config = ConfigDict(extra="forbid") - - turns: list[TranscriptTurn] = Field(..., min_length=1) - - @property - def word_count(self) -> int: - """Total spoken words, used to estimate runtime against the brief.""" - return sum(len(turn.text.split()) for turn in self.turns) diff --git a/surfsense_backend/app/podcasts/service.py b/surfsense_backend/app/podcasts/service.py deleted file mode 100644 index 165bc77a4..000000000 --- a/surfsense_backend/app/podcasts/service.py +++ /dev/null @@ -1,255 +0,0 @@ -"""The podcast lifecycle authority: every status change goes through here. - -The service owns the state machine. Each method names a real lifecycle step, -validates it against the allowed-transition table, and (de)serializes the brief -and transcript to/from their JSONB columns. It deliberately does not enqueue -Celery work — callers transition the row here, then schedule the next task — so -the rules stay testable and free of task-queue coupling. -""" - -from __future__ import annotations - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.podcasts.persistence import Podcast, PodcastRepository, PodcastStatus -from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn - -_MAX_ERROR_CHARS = 2000 - -# The only status changes the machine permits. Terminal states have no exits. -_ALLOWED: dict[PodcastStatus, frozenset[PodcastStatus]] = { - PodcastStatus.PENDING: frozenset( - {PodcastStatus.AWAITING_BRIEF, PodcastStatus.FAILED, PodcastStatus.CANCELLED} - ), - # The READY exits below exist for reverting a regeneration; the audio - # guard for that lives in revert_regeneration. - PodcastStatus.AWAITING_BRIEF: frozenset( - { - PodcastStatus.DRAFTING, - PodcastStatus.READY, - PodcastStatus.FAILED, - PodcastStatus.CANCELLED, - } - ), - PodcastStatus.DRAFTING: frozenset( - { - PodcastStatus.RENDERING, - PodcastStatus.READY, - PodcastStatus.FAILED, - PodcastStatus.CANCELLED, - } - ), - # Never entered anymore (the transcript gate was dropped); kept with exits - # so legacy rows aren't stranded. - PodcastStatus.AWAITING_REVIEW: frozenset( - {PodcastStatus.AWAITING_BRIEF, PodcastStatus.FAILED, PodcastStatus.CANCELLED} - ), - PodcastStatus.RENDERING: frozenset( - {PodcastStatus.READY, PodcastStatus.FAILED, PodcastStatus.CANCELLED} - ), - # Not terminal: regeneration reopens the brief gate so the user can tweak - # the spec before a new take is drafted. - PodcastStatus.READY: frozenset({PodcastStatus.AWAITING_BRIEF}), - PodcastStatus.FAILED: frozenset(), - PodcastStatus.CANCELLED: frozenset(), -} - - -class PodcastError(RuntimeError): - """Base class for lifecycle errors.""" - - -class InvalidTransitionError(PodcastError): - """A requested status change is not permitted from the current state.""" - - -class SpecConflictError(PodcastError): - """A spec edit raced another: the expected version is stale.""" - - def __init__(self, expected: int, actual: int) -> None: - super().__init__( - f"spec version conflict: expected {expected}, current is {actual}" - ) - self.expected = expected - self.actual = actual - - -class PreconditionFailedError(PodcastError): - """A transition's data precondition (brief/transcript present) is unmet.""" - - -class PodcastService: - """Drives one podcast through its lifecycle within a single session.""" - - def __init__(self, session: AsyncSession) -> None: - self._session = session - self._repo = PodcastRepository(session) - - async def create( - self, *, title: str, search_space_id: int, thread_id: int | None = None - ) -> Podcast: - """Create a fresh podcast in ``PENDING`` awaiting its brief.""" - podcast = Podcast( - title=title, - search_space_id=search_space_id, - thread_id=thread_id, - status=PodcastStatus.PENDING, - spec_version=1, - ) - return await self._repo.add(podcast) - - async def attach_brief(self, podcast: Podcast, spec: PodcastSpec) -> Podcast: - """Record the proposed brief and open the review gate.""" - self._transition(podcast, PodcastStatus.AWAITING_BRIEF) - podcast.spec = spec.model_dump(mode="json") - await self._session.flush() - return podcast - - async def update_spec( - self, podcast: Podcast, spec: PodcastSpec, expected_version: int - ) -> Podcast: - """Edit the brief at the gate, guarded by optimistic concurrency.""" - if _status(podcast) is not PodcastStatus.AWAITING_BRIEF: - raise InvalidTransitionError( - f"the brief can only be edited while awaiting_brief, " - f"not {_status(podcast).value}" - ) - if expected_version != podcast.spec_version: - raise SpecConflictError(expected_version, podcast.spec_version) - podcast.spec = spec.model_dump(mode="json") - podcast.spec_version += 1 - await self._session.flush() - return podcast - - async def begin_drafting(self, podcast: Podcast) -> Podcast: - """Approve the brief and start transcript drafting.""" - if podcast.spec is None: - raise PreconditionFailedError("cannot draft without a brief") - self._transition(podcast, PodcastStatus.DRAFTING) - await self._session.flush() - return podcast - - async def attach_transcript( - self, podcast: Podcast, transcript: Transcript - ) -> Podcast: - """Record the drafted transcript and move straight to rendering.""" - self._transition(podcast, PodcastStatus.RENDERING) - podcast.podcast_transcript = transcript.model_dump(mode="json") - await self._session.flush() - return podcast - - # Guards regenerate beyond the transition table: from PENDING the - # AWAITING_BRIEF target is also legal, but there it means attaching a brief. - _REGENERABLE = frozenset({PodcastStatus.READY, PodcastStatus.AWAITING_REVIEW}) - - async def regenerate(self, podcast: Podcast) -> Podcast: - """Reopen the brief gate; the saved spec becomes the new starting point.""" - if _status(podcast) not in self._REGENERABLE: - raise InvalidTransitionError( - f"nothing to regenerate from {_status(podcast).value}" - ) - # Legacy episodes finished before briefs existed; a gate with nothing - # to review would strand them. - if podcast.spec is None: - raise PreconditionFailedError("cannot regenerate without a brief") - self._transition(podcast, PodcastStatus.AWAITING_BRIEF) - await self._session.flush() - return podcast - - async def revert_regeneration(self, podcast: Podcast) -> Podcast: - """Back out of a regeneration and fall back to the stored episode. - - Regeneration keeps the rendered audio until a new take replaces it, so - any point before that commit is a free change of mind. A fresh podcast - has no regeneration to revert and is rejected. - """ - if not has_stored_episode(podcast): - raise InvalidTransitionError("no finished episode to fall back to") - self._transition(podcast, PodcastStatus.READY) - await self._session.flush() - return podcast - - async def attach_audio( - self, - podcast: Podcast, - *, - storage_backend: str, - storage_key: str, - duration_seconds: int | None = None, - ) -> Podcast: - """Record rendered audio and mark the podcast ready.""" - self._transition(podcast, PodcastStatus.READY) - podcast.storage_backend = storage_backend - podcast.storage_key = storage_key - podcast.duration_seconds = duration_seconds - podcast.error = None - await self._session.flush() - return podcast - - async def fail(self, podcast: Podcast, error: str) -> Podcast: - """Move a non-terminal podcast to ``FAILED`` with a reason.""" - self._transition(podcast, PodcastStatus.FAILED) - podcast.error = (error or "")[:_MAX_ERROR_CHARS] or None - await self._session.flush() - return podcast - - async def cancel(self, podcast: Podcast) -> Podcast: - """Cancel a podcast that has produced nothing the user could keep. - - No user action may destroy playable audio: once an episode exists, - backing out goes through revert_regeneration instead. - """ - if has_stored_episode(podcast): - raise InvalidTransitionError( - "a finished episode exists; revert the regeneration instead" - ) - self._transition(podcast, PodcastStatus.CANCELLED) - await self._session.flush() - return podcast - - def _transition(self, podcast: Podcast, target: PodcastStatus) -> None: - current = _status(podcast) - if target not in _ALLOWED[current]: - raise InvalidTransitionError( - f"{current.value} -> {target.value} is not allowed" - ) - podcast.status = target - - -def _status(podcast: Podcast) -> PodcastStatus: - return PodcastStatus(podcast.status) - - -def has_stored_episode(podcast: Podcast) -> bool: - """Whether finished audio is stored (``file_location`` covers legacy rows).""" - return bool(podcast.storage_key or podcast.file_location) - - -def read_spec(podcast: Podcast) -> PodcastSpec | None: - """Deserialize the stored brief, or ``None`` if not yet proposed.""" - return PodcastSpec.model_validate(podcast.spec) if podcast.spec else None - - -def read_transcript(podcast: Podcast) -> Transcript | None: - """Deserialize the stored transcript, or ``None`` if not yet drafted.""" - raw = podcast.podcast_transcript - if not raw: - return None - # Rows from before the lifecycle rework stored a bare turn list with - # different field names; they must keep reading, not fail validation. - if isinstance(raw, list): - return Transcript( - turns=[ - TranscriptTurn(speaker=turn["speaker_id"], text=turn["dialog"]) - for turn in raw - ] - ) - return Transcript.model_validate(raw) - - -def preferences_from(podcast: Podcast | None) -> tuple[str | None, list[str]]: - """Extract reusable (language, voice_ids) defaults from a prior podcast.""" - spec = read_spec(podcast) if podcast is not None else None - if spec is None: - return None, [] - return spec.language, [speaker.voice_id for speaker in spec.speakers] diff --git a/surfsense_backend/app/podcasts/storage.py b/surfsense_backend/app/podcasts/storage.py deleted file mode 100644 index f02429dff..000000000 --- a/surfsense_backend/app/podcasts/storage.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Durable storage for rendered podcast audio. - -Wraps the shared :class:`StorageBackend` so the rest of the module never deals -with object keys directly. Audio is stored under a per-podcast key, streamed for -download, and purged when a podcast is deleted. -""" - -from __future__ import annotations - -import uuid -from collections.abc import AsyncIterator - -from app.file_storage.factory import get_storage_backend -from app.podcasts.persistence import Podcast - -_AUDIO_CONTENT_TYPE = "audio/mpeg" - - -def build_audio_key(*, search_space_id: int, podcast_id: int) -> str: - """Object key for a podcast's audio. - - Shape: ``podcasts/{search_space_id}/{podcast_id}/{uuid}.mp3``. The uuid lets - a re-render write a fresh object before the old one is purged. - """ - return f"podcasts/{search_space_id}/{podcast_id}/{uuid.uuid4().hex}.mp3" - - -async def store_audio( - *, search_space_id: int, podcast_id: int, data: bytes -) -> tuple[str, str]: - """Persist audio bytes and return ``(backend_name, storage_key)``.""" - backend = get_storage_backend() - key = build_audio_key(search_space_id=search_space_id, podcast_id=podcast_id) - await backend.put(key, data, content_type=_AUDIO_CONTENT_TYPE) - return backend.backend_name, key - - -def open_audio_stream(podcast: Podcast) -> AsyncIterator[bytes]: - """Stream a ready podcast's audio bytes. Raises if it has none.""" - if not podcast.storage_key: - raise FileNotFoundError(f"podcast {podcast.id} has no stored audio") - return get_storage_backend().open_stream(podcast.storage_key) - - -async def purge_audio(podcast: Podcast) -> None: - """Delete a podcast's stored audio if present; a missing object is fine.""" - await purge_audio_object(podcast.storage_key) - - -async def purge_audio_object(key: str | None) -> None: - """Delete a stored audio object by key, e.g. the one a re-render replaced.""" - if key: - await get_storage_backend().delete(key) diff --git a/surfsense_backend/app/podcasts/tasks/__init__.py b/surfsense_backend/app/podcasts/tasks/__init__.py deleted file mode 100644 index cd0b7e4c4..000000000 --- a/surfsense_backend/app/podcasts/tasks/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Celery tasks driving the podcast lifecycle across its expensive phases. - -One task per heavy async phase: draft the transcript (LLM) and render the audio -(TTS). The brief is deterministic and proposed inline at create time, so it has -no task. Each task is enqueued by the API after it performs the guarded status -transition, and each pushes its result onto the row for the frontend to observe. -""" - -from __future__ import annotations - -from .draft import draft_transcript_task -from .render import render_audio_task - -__all__ = [ - "draft_transcript_task", - "render_audio_task", -] diff --git a/surfsense_backend/app/podcasts/tasks/draft.py b/surfsense_backend/app/podcasts/tasks/draft.py deleted file mode 100644 index c5b489571..000000000 --- a/surfsense_backend/app/podcasts/tasks/draft.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Transcript-drafting task: DRAFTING -> RENDERING. - -The expensive, LLM-heavy step, so it runs under ``billable_call``. The API has -already moved the row to DRAFTING and stored the approved brief; this task -drafts the long-form transcript and chains straight into the render — the brief -gate is the only approval in the lifecycle. -""" - -from __future__ import annotations - -import logging - -from app.celery_app import celery_app -from app.config import config as app_config -from app.podcasts.generation.transcript.graph import graph as transcript_graph -from app.podcasts.generation.transcript.state import TranscriptState -from app.podcasts.persistence import PodcastRepository -from app.podcasts.service import PodcastService, read_spec -from app.services.billable_calls import ( - BillingSettlementError, - QuotaInsufficientError, - _resolve_agent_billing_for_search_space, - billable_call, -) -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -from .render import render_audio_task -from .runtime import billable_session, mark_failed - -logger = logging.getLogger(__name__) - - -@celery_app.task(name="podcast.draft_transcript", bind=True) -def draft_transcript_task(self, podcast_id: int, search_space_id: int) -> dict: - try: - return run_async_celery_task( - lambda: _draft_transcript(podcast_id, search_space_id) - ) - except Exception as exc: - logger.error("Podcast %s drafting failed: %s", podcast_id, exc) - message = str(exc) - run_async_celery_task(lambda: mark_failed(podcast_id, message)) - return {"status": "failed", "podcast_id": podcast_id} - - -async def _draft_transcript(podcast_id: int, search_space_id: int) -> dict: - async with get_celery_session_maker()() as session: - repo = PodcastRepository(session) - service = PodcastService(session) - podcast = await repo.get(podcast_id) - if podcast is None: - raise ValueError(f"podcast {podcast_id} not found") - - spec = read_spec(podcast) - if spec is None: - raise ValueError(f"podcast {podcast_id} has no approved brief") - - owner_id, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id, thread_id=podcast.thread_id - ) - - state = TranscriptState( - db_session=session, source_content=podcast.source_content or "" - ) - config = { - "configurable": { - "search_space_id": search_space_id, - "spec": spec, - "focus": spec.focus, - } - } - - try: - async with billable_call( - user_id=owner_id, - search_space_id=search_space_id, - billing_tier=tier, - base_model=base_model, - quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, - usage_type="podcast_generation", - call_details={"podcast_id": podcast_id, "title": podcast.title}, - billable_session_factory=billable_session, - ): - result = await transcript_graph.ainvoke(state, config=config) - except QuotaInsufficientError: - await service.fail(podcast, "premium quota exhausted") - await session.commit() - return {"status": "failed", "podcast_id": podcast_id, "reason": "quota"} - except BillingSettlementError: - await service.fail(podcast, "billing settlement failed") - await session.commit() - return {"status": "failed", "podcast_id": podcast_id, "reason": "billing"} - - await service.attach_transcript(podcast, result["transcript"]) - await session.commit() - - # Enqueue only after the transaction is committed, so the render worker can - # never pick up a row whose transcript isn't visible yet. - render_audio_task.delay(podcast_id) - return {"status": "rendering", "podcast_id": podcast_id} diff --git a/surfsense_backend/app/podcasts/tasks/render.py b/surfsense_backend/app/podcasts/tasks/render.py deleted file mode 100644 index 2e550a868..000000000 --- a/surfsense_backend/app/podcasts/tasks/render.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Audio-rendering task: RENDERING -> READY. - -Synthesises and merges the approved transcript, stores the MP3 in the object -store, and marks the podcast ready. The working directory is stable per podcast -so a re-render (e.g. after a voice change) reuses the segment cache. -""" - -from __future__ import annotations - -import logging -import tempfile -from pathlib import Path - -from app.celery_app import celery_app -from app.podcasts.persistence import PodcastRepository -from app.podcasts.rendering import PodcastRenderer -from app.podcasts.service import ( - InvalidTransitionError, - PodcastService, - read_spec, - read_transcript, -) -from app.podcasts.storage import purge_audio_object, store_audio -from app.podcasts.tts import get_text_to_speech -from app.podcasts.voices import get_voice_catalog -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -from .runtime import mark_failed - -logger = logging.getLogger(__name__) - -_WORKDIR_BASE = Path(tempfile.gettempdir()) / "surfsense_podcasts" - - -@celery_app.task(name="podcast.render_audio", bind=True) -def render_audio_task(self, podcast_id: int) -> dict: - try: - return run_async_celery_task(lambda: _render_audio(podcast_id)) - except Exception as exc: - logger.error("Podcast %s render failed: %s", podcast_id, exc) - message = str(exc) - run_async_celery_task(lambda: mark_failed(podcast_id, message)) - return {"status": "failed", "podcast_id": podcast_id} - - -async def _render_audio(podcast_id: int) -> dict: - async with get_celery_session_maker()() as session: - repo = PodcastRepository(session) - podcast = await repo.get(podcast_id) - if podcast is None: - raise ValueError(f"podcast {podcast_id} not found") - - spec = read_spec(podcast) - transcript = read_transcript(podcast) - if spec is None or transcript is None: - raise ValueError(f"podcast {podcast_id} is missing brief or transcript") - - renderer = PodcastRenderer( - tts=get_text_to_speech(), catalog=get_voice_catalog() - ) - workdir = _WORKDIR_BASE / str(podcast_id) - workdir.mkdir(parents=True, exist_ok=True) - rendered = await renderer.render( - spec=spec, transcript=transcript, workdir=workdir - ) - - superseded_key = podcast.storage_key - - backend_name, key = await store_audio( - search_space_id=podcast.search_space_id, - podcast_id=podcast_id, - data=rendered.data, - ) - try: - await PodcastService(session).attach_audio( - podcast, storage_backend=backend_name, storage_key=key - ) - await session.commit() - except InvalidTransitionError: - # A user back-out won the race (e.g. the regeneration was - # reverted): drop the stale render and leave the row alone. - await purge_audio_object(key) - return {"status": "superseded", "podcast_id": podcast_id} - - # Purge only after the new audio is committed, so a failed re-render never - # destroys the episode the user can still play. - await purge_audio_object(superseded_key) - return {"status": "ready", "podcast_id": podcast_id} diff --git a/surfsense_backend/app/podcasts/tasks/runtime.py b/surfsense_backend/app/podcasts/tasks/runtime.py deleted file mode 100644 index 349aeffb2..000000000 --- a/surfsense_backend/app/podcasts/tasks/runtime.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Shared plumbing for the podcast Celery tasks. - -Each task runs its async body via :func:`run_async_celery_task` and, on any -failure, records the reason on the row through the lifecycle service. Marking -failed is best-effort: a podcast that already reached a terminal state is left -untouched rather than forced. -""" - -from __future__ import annotations - -import logging -from contextlib import asynccontextmanager - -from app.podcasts.persistence import PodcastRepository -from app.podcasts.service import PodcastError, PodcastService -from app.tasks.celery_tasks import get_celery_session_maker - -logger = logging.getLogger(__name__) - - -@asynccontextmanager -async def billable_session(): - """Session factory for ``billable_call`` inside the worker loop.""" - async with get_celery_session_maker()() as session: - yield session - - -async def mark_failed(podcast_id: int, error: str) -> None: - """Best-effort: move a non-terminal podcast to FAILED with ``error``.""" - async with get_celery_session_maker()() as session: - repo = PodcastRepository(session) - podcast = await repo.get(podcast_id) - if podcast is None: - return - try: - await PodcastService(session).fail(podcast, error) - await session.commit() - except PodcastError: - # Already terminal (e.g. cancelled): nothing to record. - logger.info("Podcast %s already terminal; not marking failed", podcast_id) diff --git a/surfsense_backend/app/podcasts/tts/__init__.py b/surfsense_backend/app/podcasts/tts/__init__.py deleted file mode 100644 index 16379dc2b..000000000 --- a/surfsense_backend/app/podcasts/tts/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Text-to-speech: a per-segment synthesis port with provider adapters. - -Callers depend on :class:`TextToSpeech` and obtain the configured provider from -:func:`get_text_to_speech`; the concrete Kokoro/LiteLLM adapters stay private. -""" - -from __future__ import annotations - -from .audio import SynthesizedAudio -from .errors import TextToSpeechError -from .factory import get_text_to_speech -from .port import TextToSpeech -from .request import SynthesisRequest, VoiceRef - -__all__ = [ - "SynthesisRequest", - "SynthesizedAudio", - "TextToSpeech", - "TextToSpeechError", - "VoiceRef", - "get_text_to_speech", -] diff --git a/surfsense_backend/app/podcasts/tts/adapters/__init__.py b/surfsense_backend/app/podcasts/tts/adapters/__init__.py deleted file mode 100644 index 24d517e55..000000000 --- a/surfsense_backend/app/podcasts/tts/adapters/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Per-provider TextToSpeech implementations.""" - -from __future__ import annotations diff --git a/surfsense_backend/app/podcasts/tts/adapters/kokoro.py b/surfsense_backend/app/podcasts/tts/adapters/kokoro.py deleted file mode 100644 index 2ef0069c5..000000000 --- a/surfsense_backend/app/podcasts/tts/adapters/kokoro.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Local Kokoro adapter: on-box synthesis, no network or per-segment cost. - -Kokoro selects its language model by a single-letter ``lang_code``, so this -adapter maps the brief's BCP-47 tag to that code and caches one pipeline per -code (pipeline construction loads weights and is expensive). Pipelines run in a -thread pool because Kokoro is synchronous; the renderer caps how many segments -synthesise at once. -""" - -from __future__ import annotations - -import asyncio -import io -from typing import TYPE_CHECKING - -from ..audio import SynthesizedAudio -from ..errors import TextToSpeechError -from ..port import TextToSpeech -from ..request import SynthesisRequest - -if TYPE_CHECKING: - from kokoro import KPipeline - -# Kokoro emits 24 kHz mono PCM regardless of voice. -_SAMPLE_RATE = 24000 - -# BCP-47 primary subtag -> Kokoro language code. English defaults to American; -# the en-GB region override below switches it to British. -_LANG_CODE_BY_PRIMARY = { - "en": "a", - "es": "e", - "fr": "f", - "hi": "h", - "it": "i", - "ja": "j", - "pt": "p", - "zh": "z", -} - - -class KokoroTextToSpeech(TextToSpeech): - """Synthesises segments with locally hosted Kokoro pipelines.""" - - def __init__(self) -> None: - self._pipelines: dict[str, KPipeline] = {} - - @property - def container(self) -> str: - return "wav" - - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - if not isinstance(request.voice, str): - raise TextToSpeechError("Kokoro voices are named by string, not a mapping") - - pipeline = self._pipeline_for(request.language) - loop = asyncio.get_event_loop() - try: - generator = await loop.run_in_executor( - None, - lambda: pipeline( - request.text, - voice=request.voice, - speed=request.speed, - split_pattern=r"\n+", - ), - ) - segments = [audio for _gs, _ps, audio in generator] - except Exception as exc: - raise TextToSpeechError(f"Kokoro synthesis failed: {exc}") from exc - - if not segments: - raise TextToSpeechError("Kokoro produced no audio for the text") - - return SynthesizedAudio( - data=_encode_wav(segments, _SAMPLE_RATE), - container="wav", - sample_rate=_SAMPLE_RATE, - ) - - def _pipeline_for(self, language: str) -> KPipeline: - lang_code = _lang_code(language) - pipeline = self._pipelines.get(lang_code) - if pipeline is None: - from kokoro import KPipeline - - pipeline = KPipeline(lang_code=lang_code) - self._pipelines[lang_code] = pipeline - return pipeline - - -def _lang_code(language: str) -> str: - normalised = language.strip().lower() - if normalised.startswith("en-gb") or normalised == "en-uk": - return "b" - primary = normalised.partition("-")[0] - code = _LANG_CODE_BY_PRIMARY.get(primary) - if code is None: - raise TextToSpeechError(f"Kokoro has no language model for {language!r}") - return code - - -def _encode_wav(segments: list, sample_rate: int) -> bytes: - import numpy as np - import soundfile as sf - - waveform = segments[0] if len(segments) == 1 else np.concatenate(segments) - buffer = io.BytesIO() - sf.write(buffer, waveform, sample_rate, format="WAV") - return buffer.getvalue() diff --git a/surfsense_backend/app/podcasts/tts/adapters/litellm.py b/surfsense_backend/app/podcasts/tts/adapters/litellm.py deleted file mode 100644 index d0014c5cd..000000000 --- a/surfsense_backend/app/podcasts/tts/adapters/litellm.py +++ /dev/null @@ -1,67 +0,0 @@ -"""LiteLLM adapter: hosted TTS (OpenAI, Azure, Vertex AI) via one ``aspeech`` call. - -LiteLLM normalises every hosted provider behind the same ``aspeech`` surface, -so a single adapter covers them all. The provider is encoded in the model -string (e.g. ``openai/tts-1``, ``vertex_ai/...``) and the voice reference is -whatever that provider expects, which the catalog already supplies. -""" - -from __future__ import annotations - -from ..audio import SynthesizedAudio -from ..errors import TextToSpeechError -from ..port import TextToSpeech -from ..request import SynthesisRequest - -# Hosted providers return MP3-encoded bytes from ``aspeech``. -_CONTAINER = "mp3" - -# A long single segment still finishes well under this; retries absorb transient -# upstream failures without failing the whole render. -_TIMEOUT_SECONDS = 600 -_MAX_RETRIES = 2 - - -class LiteLlmTextToSpeech(TextToSpeech): - """Synthesises segments through any LiteLLM-supported hosted TTS model.""" - - def __init__( - self, - *, - model: str, - api_base: str | None = None, - api_key: str | None = None, - ) -> None: - self._model = model - self._api_base = api_base - self._api_key = api_key - - @property - def container(self) -> str: - return _CONTAINER - - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - from litellm import aspeech - - kwargs = { - "model": self._model, - "voice": request.voice, - "input": request.text, - "max_retries": _MAX_RETRIES, - "timeout": _TIMEOUT_SECONDS, - } - if self._api_base: - kwargs["api_base"] = self._api_base - if self._api_key: - kwargs["api_key"] = self._api_key - - try: - response = await aspeech(**kwargs) - except Exception as exc: - raise TextToSpeechError(f"{self._model} synthesis failed: {exc}") from exc - - data = getattr(response, "content", None) - if not data: - raise TextToSpeechError(f"{self._model} returned no audio") - - return SynthesizedAudio(data=data, container=_CONTAINER) diff --git a/surfsense_backend/app/podcasts/tts/audio.py b/surfsense_backend/app/podcasts/tts/audio.py deleted file mode 100644 index f3c79dd5a..000000000 --- a/surfsense_backend/app/podcasts/tts/audio.py +++ /dev/null @@ -1,19 +0,0 @@ -"""The bytes a TTS provider returns for one segment.""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class SynthesizedAudio: - """Encoded audio for a single segment, ready to cache and concatenate. - - ``container`` is the file extension the bytes are encoded as (``"wav"`` or - ``"mp3"``); the renderer uses it to name the on-disk segment so FFmpeg can - demux the right format during merge. - """ - - data: bytes - container: str - sample_rate: int | None = None diff --git a/surfsense_backend/app/podcasts/tts/errors.py b/surfsense_backend/app/podcasts/tts/errors.py deleted file mode 100644 index 8e7ec3f2b..000000000 --- a/surfsense_backend/app/podcasts/tts/errors.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Failures raised by the TTS layer.""" - -from __future__ import annotations - - -class TextToSpeechError(RuntimeError): - """A provider failed to synthesise a segment. - - Raised for both configuration faults (an unusable voice reference) and - provider faults (the upstream call errored or returned no audio), so the - renderer can fail the segment without unwrapping provider-specific - exceptions. - """ diff --git a/surfsense_backend/app/podcasts/tts/factory.py b/surfsense_backend/app/podcasts/tts/factory.py deleted file mode 100644 index 7b4a48adf..000000000 --- a/surfsense_backend/app/podcasts/tts/factory.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Resolve the configured :class:`TextToSpeech` as a process-wide singleton.""" - -from __future__ import annotations - -from functools import lru_cache - -from .port import TextToSpeech - -# Sentinel model string that selects the local Kokoro pipeline; anything else is -# treated as a LiteLLM-hosted model (``openai/...``, ``vertex_ai/...``, etc.). -KOKORO_SERVICE = "local/kokoro" - - -@lru_cache(maxsize=1) -def get_text_to_speech() -> TextToSpeech: - """Build the provider selected by ``TTS_SERVICE`` (adapters lazy-imported). - - Cached because the Kokoro adapter holds loaded pipelines that must be reused - across segments and requests rather than rebuilt per call. - """ - from app.config import config as app_config - - service = app_config.TTS_SERVICE - if not service: - raise ValueError("TTS_SERVICE is not configured") - - if service == KOKORO_SERVICE: - from .adapters.kokoro import KokoroTextToSpeech - - return KokoroTextToSpeech() - - from .adapters.litellm import LiteLlmTextToSpeech - - return LiteLlmTextToSpeech( - model=service, - api_base=app_config.TTS_SERVICE_API_BASE, - api_key=app_config.TTS_SERVICE_API_KEY, - ) diff --git a/surfsense_backend/app/podcasts/tts/port.py b/surfsense_backend/app/podcasts/tts/port.py deleted file mode 100644 index 604708260..000000000 --- a/surfsense_backend/app/podcasts/tts/port.py +++ /dev/null @@ -1,31 +0,0 @@ -"""The TTS contract: turn one segment of text into encoded audio.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod - -from .audio import SynthesizedAudio -from .request import SynthesisRequest - - -class TextToSpeech(ABC): - """Synthesises a single segment; one implementation per provider. - - The contract is intentionally per-segment rather than per-episode: it keeps - each call independently cacheable and lets the renderer cap concurrency and - retry segments in isolation. Stitching segments into one file is the - renderer's job, not the provider's. - """ - - @property - @abstractmethod - def container(self) -> str: - """File extension/container this provider emits (e.g. ``"mp3"``).""" - - @abstractmethod - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - """Voice ``request.text`` and return its encoded audio. - - Raises :class:`~app.podcasts.tts.errors.TextToSpeechError` on any - provider or configuration failure. - """ diff --git a/surfsense_backend/app/podcasts/tts/request.py b/surfsense_backend/app/podcasts/tts/request.py deleted file mode 100644 index 2cb5f6ec4..000000000 --- a/surfsense_backend/app/podcasts/tts/request.py +++ /dev/null @@ -1,22 +0,0 @@ -"""What the renderer hands a TTS provider to voice a single segment.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from typing import Any - -# A provider-native voice reference. OpenAI/Azure/Kokoro name a voice with a -# string; Vertex passes a mapping (``languageCode`` + ``name``). The catalog -# stores whichever shape the provider expects and we pass it through untouched. -VoiceRef = str | Mapping[str, Any] - - -@dataclass(frozen=True, slots=True) -class SynthesisRequest: - """One unit of speech to synthesise: the smallest cacheable render step.""" - - text: str - voice: VoiceRef - language: str - speed: float = 1.0 diff --git a/surfsense_backend/app/podcasts/voices/__init__.py b/surfsense_backend/app/podcasts/voices/__init__.py deleted file mode 100644 index ab1f8bbbf..000000000 --- a/surfsense_backend/app/podcasts/voices/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Voices: the catalog of selectable TTS voices and the active provider. - -Callers obtain the catalog via :func:`get_voice_catalog` and identify the -configured provider via :func:`provider_from_service`. -""" - -from __future__ import annotations - -from .catalog import VoiceCatalog, get_voice_catalog -from .preview import render_voice_preview -from .provider import TtsProvider, provider_from_service -from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender - -__all__ = [ - "ANY_LANGUAGE", - "CatalogVoice", - "TtsProvider", - "VoiceCatalog", - "VoiceGender", - "get_voice_catalog", - "provider_from_service", - "render_voice_preview", -] diff --git a/surfsense_backend/app/podcasts/voices/catalog.py b/surfsense_backend/app/podcasts/voices/catalog.py deleted file mode 100644 index c36313a0c..000000000 --- a/surfsense_backend/app/podcasts/voices/catalog.py +++ /dev/null @@ -1,51 +0,0 @@ -"""The voice catalog: look up and filter selectable voices. - -A :class:`VoiceCatalog` is the single source of truth for which voices exist. -Resolution uses it to pick defaults for a brief, the API exposes it as picker -options, and the renderer uses it to turn a stored ``voice_id`` back into the -provider-native reference. -""" - -from __future__ import annotations - -from collections.abc import Iterable -from functools import lru_cache - -from .data import AZURE_VOICES, KOKORO_VOICES, OPENAI_VOICES, VERTEX_VOICES -from .provider import TtsProvider -from .voice import CatalogVoice - - -class VoiceCatalog: - """An indexed, read-only collection of :class:`CatalogVoice`.""" - - def __init__(self, voices: Iterable[CatalogVoice]) -> None: - self._by_id: dict[str, CatalogVoice] = {} - self._by_provider: dict[TtsProvider, list[CatalogVoice]] = {} - for voice in voices: - if voice.voice_id in self._by_id: - raise ValueError(f"duplicate voice_id: {voice.voice_id}") - self._by_id[voice.voice_id] = voice - self._by_provider.setdefault(voice.provider, []).append(voice) - - def get(self, voice_id: str) -> CatalogVoice: - """Return the voice with ``voice_id`` or raise ``KeyError``.""" - return self._by_id[voice_id] - - def for_provider(self, provider: TtsProvider) -> list[CatalogVoice]: - """All voices offered by ``provider``, in catalog order.""" - return list(self._by_provider.get(provider, ())) - - def for_language(self, provider: TtsProvider, language: str) -> list[CatalogVoice]: - """``provider`` voices that can render ``language``, in catalog order.""" - return [v for v in self.for_provider(provider) if v.speaks(language)] - - def supports_language(self, provider: TtsProvider, language: str) -> bool: - """Whether ``provider`` has at least one voice for ``language``.""" - return any(v.speaks(language) for v in self.for_provider(provider)) - - -@lru_cache(maxsize=1) -def get_voice_catalog() -> VoiceCatalog: - """The process-wide catalog assembled from every provider's roster.""" - return VoiceCatalog((*KOKORO_VOICES, *OPENAI_VOICES, *AZURE_VOICES, *VERTEX_VOICES)) diff --git a/surfsense_backend/app/podcasts/voices/data/__init__.py b/surfsense_backend/app/podcasts/voices/data/__init__.py deleted file mode 100644 index 5316f10f6..000000000 --- a/surfsense_backend/app/podcasts/voices/data/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Static per-provider voice rosters that compose the catalog.""" - -from __future__ import annotations - -from .azure import AZURE_VOICES -from .kokoro import KOKORO_VOICES -from .openai import OPENAI_VOICES -from .vertex import VERTEX_VOICES - -__all__ = ["AZURE_VOICES", "KOKORO_VOICES", "OPENAI_VOICES", "VERTEX_VOICES"] diff --git a/surfsense_backend/app/podcasts/voices/data/azure.py b/surfsense_backend/app/podcasts/voices/data/azure.py deleted file mode 100644 index 104ab766d..000000000 --- a/surfsense_backend/app/podcasts/voices/data/azure.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Azure TTS voices, routed through the OpenAI-compatible voice names. - -The deployment fronts Azure with OpenAI-style voice names (matching the legacy -podcaster), so these mirror the OpenAI roster and, like it, speak any requested -language. -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import ANY_LANGUAGE, CatalogVoice, VoiceGender - - -def _voice(name: str, display: str, gender: VoiceGender) -> CatalogVoice: - return CatalogVoice( - voice_id=f"azure:{name}", - provider=TtsProvider.AZURE, - language=ANY_LANGUAGE, - display_name=display, - gender=gender, - native_ref=name, - ) - - -AZURE_VOICES: tuple[CatalogVoice, ...] = ( - _voice("alloy", "Alloy", VoiceGender.NEUTRAL), - _voice("echo", "Echo", VoiceGender.MALE), - _voice("fable", "Fable", VoiceGender.NEUTRAL), - _voice("onyx", "Onyx", VoiceGender.MALE), - _voice("nova", "Nova", VoiceGender.FEMALE), - _voice("shimmer", "Shimmer", VoiceGender.FEMALE), -) diff --git a/surfsense_backend/app/podcasts/voices/data/kokoro.py b/surfsense_backend/app/podcasts/voices/data/kokoro.py deleted file mode 100644 index 732dced23..000000000 --- a/surfsense_backend/app/podcasts/voices/data/kokoro.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Curated Kokoro voices, the local provider's multilingual roster. - -Kokoro voice names encode language and gender in their first two letters -(``a``=American English, ``b``=British, ``e``=Spanish, ``f``=French, -``h``=Hindi, ``i``=Italian, ``j``=Japanese, ``p``=Brazilian Portuguese, -``z``=Mandarin; second letter ``f``/``m`` = female/male). We carry at least one -male and one female voice per language so a two-speaker brief always has a -distinct pair. ``native_ref`` is the bare voice name Kokoro expects. - -Reference: https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import CatalogVoice, VoiceGender - - -def _voice(name: str, language: str, display: str, gender: VoiceGender) -> CatalogVoice: - return CatalogVoice( - voice_id=f"kokoro:{name}", - provider=TtsProvider.KOKORO, - language=language, - display_name=display, - gender=gender, - native_ref=name, - ) - - -KOKORO_VOICES: tuple[CatalogVoice, ...] = ( - # American English - _voice("am_adam", "en-US", "Adam (US)", VoiceGender.MALE), - _voice("am_michael", "en-US", "Michael (US)", VoiceGender.MALE), - _voice("af_bella", "en-US", "Bella (US)", VoiceGender.FEMALE), - _voice("af_heart", "en-US", "Heart (US)", VoiceGender.FEMALE), - _voice("af_nicole", "en-US", "Nicole (US)", VoiceGender.FEMALE), - _voice("af_sarah", "en-US", "Sarah (US)", VoiceGender.FEMALE), - # British English - _voice("bm_george", "en-GB", "George (UK)", VoiceGender.MALE), - _voice("bm_lewis", "en-GB", "Lewis (UK)", VoiceGender.MALE), - _voice("bf_emma", "en-GB", "Emma (UK)", VoiceGender.FEMALE), - _voice("bf_isabella", "en-GB", "Isabella (UK)", VoiceGender.FEMALE), - # Spanish - _voice("em_alex", "es", "Alex (ES)", VoiceGender.MALE), - _voice("ef_dora", "es", "Dora (ES)", VoiceGender.FEMALE), - # French - _voice("ff_siwis", "fr", "Siwis (FR)", VoiceGender.FEMALE), - # Hindi - _voice("hm_omega", "hi", "Omega (HI)", VoiceGender.MALE), - _voice("hf_alpha", "hi", "Alpha (HI)", VoiceGender.FEMALE), - # Italian - _voice("im_nicola", "it", "Nicola (IT)", VoiceGender.MALE), - _voice("if_sara", "it", "Sara (IT)", VoiceGender.FEMALE), - # Japanese - _voice("jm_kumo", "ja", "Kumo (JA)", VoiceGender.MALE), - _voice("jf_alpha", "ja", "Alpha (JA)", VoiceGender.FEMALE), - # Brazilian Portuguese - _voice("pm_alex", "pt-BR", "Alex (BR)", VoiceGender.MALE), - _voice("pf_dora", "pt-BR", "Dora (BR)", VoiceGender.FEMALE), - # Mandarin Chinese - _voice("zm_yunxi", "zh", "Yunxi (ZH)", VoiceGender.MALE), - _voice("zf_xiaoxiao", "zh", "Xiaoxiao (ZH)", VoiceGender.FEMALE), -) diff --git a/surfsense_backend/app/podcasts/voices/data/openai.py b/surfsense_backend/app/podcasts/voices/data/openai.py deleted file mode 100644 index ce5c480c5..000000000 --- a/surfsense_backend/app/podcasts/voices/data/openai.py +++ /dev/null @@ -1,32 +0,0 @@ -"""OpenAI TTS voices: language-agnostic, so each speaks any requested language. - -OpenAI voices follow the language of the input text rather than being tied to a -locale, so they are tagged :data:`ANY_LANGUAGE` and match every brief. The -``native_ref`` is the plain voice name the API expects. -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import ANY_LANGUAGE, CatalogVoice, VoiceGender - - -def _voice(name: str, display: str, gender: VoiceGender) -> CatalogVoice: - return CatalogVoice( - voice_id=f"openai:{name}", - provider=TtsProvider.OPENAI, - language=ANY_LANGUAGE, - display_name=display, - gender=gender, - native_ref=name, - ) - - -OPENAI_VOICES: tuple[CatalogVoice, ...] = ( - _voice("alloy", "Alloy", VoiceGender.NEUTRAL), - _voice("echo", "Echo", VoiceGender.MALE), - _voice("fable", "Fable", VoiceGender.NEUTRAL), - _voice("onyx", "Onyx", VoiceGender.MALE), - _voice("nova", "Nova", VoiceGender.FEMALE), - _voice("shimmer", "Shimmer", VoiceGender.FEMALE), -) diff --git a/surfsense_backend/app/podcasts/voices/data/vertex.py b/surfsense_backend/app/podcasts/voices/data/vertex.py deleted file mode 100644 index 99477eb21..000000000 --- a/surfsense_backend/app/podcasts/voices/data/vertex.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Vertex AI Studio voices: locale-specific, referenced by a mapping. - -Vertex voices are tied to a locale and named via a ``{languageCode, name}`` -mapping, which is exactly the ``native_ref`` the LiteLLM adapter forwards. The -values mirror the legacy podcaster's English Studio voices. -""" - -from __future__ import annotations - -from ..provider import TtsProvider -from ..voice import CatalogVoice, VoiceGender - - -def _voice( - key: str, - language: str, - locale: str, - name: str, - display: str, - gender: VoiceGender, -) -> CatalogVoice: - return CatalogVoice( - voice_id=f"vertex_ai:{key}", - provider=TtsProvider.VERTEX_AI, - language=language, - display_name=display, - gender=gender, - native_ref={"languageCode": locale, "name": name}, - ) - - -VERTEX_VOICES: tuple[CatalogVoice, ...] = ( - _voice( - "en-US-Studio-O", - "en-US", - "en-US", - "en-US-Studio-O", - "Studio O (US)", - VoiceGender.FEMALE, - ), - _voice( - "en-US-Studio-M", - "en-US", - "en-US", - "en-US-Studio-M", - "Studio M (US)", - VoiceGender.MALE, - ), - _voice( - "en-GB-Studio-A", - "en-GB", - "en-UK", - "en-UK-Studio-A", - "Studio A (UK)", - VoiceGender.FEMALE, - ), - _voice( - "en-GB-Studio-B", - "en-GB", - "en-UK", - "en-UK-Studio-B", - "Studio B (UK)", - VoiceGender.MALE, - ), - _voice( - "en-AU-Studio-A", - "en-AU", - "en-AU", - "en-AU-Studio-A", - "Studio A (AU)", - VoiceGender.FEMALE, - ), - _voice( - "en-AU-Studio-B", - "en-AU", - "en-AU", - "en-AU-Studio-B", - "Studio B (AU)", - VoiceGender.MALE, - ), -) diff --git a/surfsense_backend/app/podcasts/voices/preview.py b/surfsense_backend/app/podcasts/voices/preview.py deleted file mode 100644 index 868504a91..000000000 --- a/surfsense_backend/app/podcasts/voices/preview.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Audible previews so users pick voices by sound, not by name. - -A preview is a short sample sentence synthesised in the voice's own language. -Samples are served through the same content-addressed cache the renderer uses, -so each voice costs at most one synthesis per cache lifetime — repeat listens -while comparing voices are free. -""" - -from __future__ import annotations - -import tempfile -from pathlib import Path - -from app.podcasts.rendering.cache import SegmentCache -from app.podcasts.tts import SynthesisRequest, TextToSpeech - -from .voice import ANY_LANGUAGE, CatalogVoice - -# Previews are user-independent, so one rendered sample serves everyone. -PREVIEW_CACHE_ROOT = Path(tempfile.gettempdir()) / "surfsense_podcasts" / "previews" - -_FALLBACK_LANGUAGE = "en" - -# A voice previews best speaking its own language. -_SAMPLE_TEXTS = { - "en": "Hi there! This is how I sound when narrating your podcast.", - "es": "¡Hola! Así sueno cuando narro tu pódcast.", - "fr": "Bonjour ! Voici ma voix quand je raconte votre podcast.", - "hi": "नमस्ते! आपका पॉडकास्ट सुनाते समय मेरी आवाज़ ऐसी होती है।", - "it": "Ciao! Questa è la mia voce quando racconto il tuo podcast.", - "ja": "こんにちは。ポッドキャストをお届けするときの私の声です。", - "pt": "Olá! É assim que eu soo ao narrar o seu podcast.", - "zh": "你好!这就是我为你播报播客时的声音。", # noqa: RUF001 -} - -_CONTENT_TYPES = {"mp3": "audio/mpeg", "wav": "audio/wav"} - - -async def render_voice_preview( - voice: CatalogVoice, tts: TextToSpeech -) -> tuple[bytes, str]: - """Return ``(audio_bytes, content_type)`` for a sample spoken by ``voice``.""" - language = _FALLBACK_LANGUAGE if voice.language == ANY_LANGUAGE else voice.language - request = SynthesisRequest( - text=_sample_text(language), voice=voice.native_ref, language=language - ) - - cache = SegmentCache(PREVIEW_CACHE_ROOT) - key = cache.key(request) - cached = cache.get(key, tts.container) - if cached is not None: - return cached.read_bytes(), _content_type(tts.container) - - audio = await tts.synthesize(request) - cache.put(key, audio.container, audio.data) - return audio.data, _content_type(audio.container) - - -def _sample_text(language: str) -> str: - primary = language.split("-", 1)[0].strip().lower() - return _SAMPLE_TEXTS.get(primary, _SAMPLE_TEXTS[_FALLBACK_LANGUAGE]) - - -def _content_type(container: str) -> str: - return _CONTENT_TYPES.get(container, "application/octet-stream") diff --git a/surfsense_backend/app/podcasts/voices/provider.py b/surfsense_backend/app/podcasts/voices/provider.py deleted file mode 100644 index f57ae11cc..000000000 --- a/surfsense_backend/app/podcasts/voices/provider.py +++ /dev/null @@ -1,27 +0,0 @@ -"""The TTS providers we carry voices for, and how to name one from config.""" - -from __future__ import annotations - -from enum import StrEnum - - -class TtsProvider(StrEnum): - """A speech provider whose voices the catalog enumerates.""" - - KOKORO = "kokoro" - OPENAI = "openai" - AZURE = "azure" - VERTEX_AI = "vertex_ai" - - -def provider_from_service(service: str) -> TtsProvider: - """Map a ``TTS_SERVICE`` string to its provider. - - The config value is a LiteLLM-style ``provider/model`` string - (``openai/tts-1``, ``vertex_ai/...``) except for local Kokoro, which is - spelled ``local/kokoro``; both halves of that special case resolve here. - """ - prefix = service.split("/", 1)[0].strip().lower() - if prefix == "local": - return TtsProvider.KOKORO - return TtsProvider(prefix) diff --git a/surfsense_backend/app/podcasts/voices/voice.py b/surfsense_backend/app/podcasts/voices/voice.py deleted file mode 100644 index 6478f04b0..000000000 --- a/surfsense_backend/app/podcasts/voices/voice.py +++ /dev/null @@ -1,50 +0,0 @@ -"""A catalog voice: a stable id paired with its provider-native reference.""" - -from __future__ import annotations - -from dataclasses import dataclass -from enum import StrEnum - -from app.podcasts.tts import VoiceRef - -from .provider import TtsProvider - -# A voice that speaks whatever language the input text is in (e.g. OpenAI's -# voices), matched against every requested language. -ANY_LANGUAGE = "*" - - -class VoiceGender(StrEnum): - """Perceived voice gender, used to pick distinct voices per speaker.""" - - MALE = "male" - FEMALE = "female" - NEUTRAL = "neutral" - - -@dataclass(frozen=True, slots=True) -class CatalogVoice: - """One selectable voice. - - ``voice_id`` is the provider-prefixed, stable id stored on a speaker in the - brief (e.g. ``"kokoro:am_adam"``). ``native_ref`` is the untyped value the - TTS adapter passes to the provider — a string for most, a mapping for - Vertex — kept separate so renaming the catalog id never breaks synthesis. - """ - - voice_id: str - provider: TtsProvider - language: str - display_name: str - gender: VoiceGender - native_ref: VoiceRef - - def speaks(self, language: str) -> bool: - """Whether this voice can render ``language`` (primary subtag match).""" - if self.language == ANY_LANGUAGE: - return True - return _primary(self.language) == _primary(language) - - -def _primary(language: str) -> str: - return language.split("-", 1)[0].strip().lower() diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index a050651f6..5cc029884 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -4,7 +4,6 @@ from app.automations.api import router as automations_router from app.file_storage.api import router as file_storage_router from app.gateway import require_gateway_enabled from app.notifications.api import router as notifications_router -from app.podcasts.api import router as podcasts_router from .agent_action_log_route import router as agent_action_log_router from .agent_flags_route import router as agent_flags_router @@ -51,6 +50,7 @@ from .notes_routes import router as notes_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router from .onedrive_add_connector_route import router as onedrive_add_connector_router +from .podcasts_routes import router as podcasts_router from .prompts_routes import router as prompts_router from .public_chat_routes import router as public_chat_router from .rbac_routes import router as rbac_router diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 33caf8453..018234ad5 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -622,10 +622,11 @@ async def create_image_generation( detail={ "error_code": "premium_quota_exhausted", "usage_type": exc.usage_type, - "balance_micros": exc.balance_micros, + "used_micros": exc.used_micros, + "limit_micros": exc.limit_micros, "remaining_micros": exc.remaining_micros, "message": ( - "Out of credits for image generation. " + "Out of premium credits for image generation. " "Purchase additional credits or switch to a free model." ), }, diff --git a/surfsense_backend/app/routes/incentive_tasks_routes.py b/surfsense_backend/app/routes/incentive_tasks_routes.py index 1dae09a2d..496b07d06 100644 --- a/surfsense_backend/app/routes/incentive_tasks_routes.py +++ b/surfsense_backend/app/routes/incentive_tasks_routes.py @@ -1,6 +1,6 @@ """ Incentive Tasks API routes. -Allows users to complete tasks (like starring GitHub repo) to earn free credits. +Allows users to complete tasks (like starring GitHub repo) to earn free pages. Each task can only be completed once per user. """ @@ -42,21 +42,21 @@ async def get_incentive_tasks( # Build task list with completion status tasks = [] - total_credit_micros_earned = 0 + total_pages_earned = 0 for task_type, config in INCENTIVE_TASKS_CONFIG.items(): completed_task = completed_tasks.get(task_type) is_completed = completed_task is not None if is_completed: - total_credit_micros_earned += completed_task.credit_micros_awarded + total_pages_earned += completed_task.pages_awarded tasks.append( IncentiveTaskInfo( task_type=task_type, title=config["title"], description=config["description"], - credit_micros_reward=config["credit_micros_reward"], + pages_reward=config["pages_reward"], action_url=config["action_url"], completed=is_completed, completed_at=completed_task.completed_at if completed_task else None, @@ -65,7 +65,7 @@ async def get_incentive_tasks( return IncentiveTasksResponse( tasks=tasks, - total_credit_micros_earned=total_credit_micros_earned, + total_pages_earned=total_pages_earned, ) @@ -79,10 +79,10 @@ async def complete_task( session: AsyncSession = Depends(get_async_session), ) -> CompleteTaskResponse | TaskAlreadyCompletedResponse: """ - Mark an incentive task as completed and award credit to the user. + Mark an incentive task as completed and award pages to the user. Each task can only be completed once. If the task was already completed, - returns the existing completion information without awarding additional credit. + returns the existing completion information without awarding additional pages. """ # Validate task type exists in config task_config = INCENTIVE_TASKS_CONFIG.get(task_type) @@ -109,23 +109,25 @@ async def complete_task( ) # Create the task completion record - credit_micros_reward = task_config["credit_micros_reward"] + pages_reward = task_config["pages_reward"] new_task = UserIncentiveTask( user_id=user.id, task_type=task_type, - credit_micros_awarded=credit_micros_reward, + pages_awarded=pages_reward, ) session.add(new_task) - # Add the reward directly to the user's spendable wallet balance. - user.credit_micros_balance = user.credit_micros_balance + credit_micros_reward + # pages_used can exceed pages_limit when a document's final page count is + # determined after processing. Base the new limit on the higher of the two + # so the rewarded pages are fully usable above the current high-water mark. + user.pages_limit = max(user.pages_used, user.pages_limit) + pages_reward await session.commit() await session.refresh(user) return CompleteTaskResponse( success=True, - message=f"Task completed! You earned ${credit_micros_reward / 1_000_000:.2f} of credit.", - credit_micros_awarded=credit_micros_reward, - new_balance_micros=user.credit_micros_balance, + message=f"Task completed! You earned {pages_reward} pages.", + pages_awarded=pages_reward, + new_pages_limit=user.pages_limit, ) diff --git a/surfsense_backend/app/routes/podcasts_routes.py b/surfsense_backend/app/routes/podcasts_routes.py new file mode 100644 index 000000000..f991f698f --- /dev/null +++ b/surfsense_backend/app/routes/podcasts_routes.py @@ -0,0 +1,211 @@ +""" +Podcast routes for CRUD operations and audio streaming. + +These routes support the podcast generation feature in new-chat. +Frontend polls GET /podcasts/{podcast_id} to check status field. +""" + +import os +from pathlib import Path + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + Permission, + Podcast, + SearchSpace, + SearchSpaceMembership, + User, + get_async_session, +) +from app.schemas import PodcastRead +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() + + +@router.get("/podcasts", response_model=list[PodcastRead]) +async def read_podcasts( + skip: int = 0, + limit: int = 100, + search_space_id: int | None = None, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List podcasts the user has access to. + Requires PODCASTS_READ permission for the search space(s). + """ + if skip < 0 or limit < 1: + raise HTTPException(status_code=400, detail="Invalid pagination parameters") + try: + if search_space_id is not None: + # Check permission for specific search space + await check_permission( + session, + user, + search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to read podcasts in this search space", + ) + result = await session.execute( + select(Podcast) + .filter(Podcast.search_space_id == search_space_id) + .offset(skip) + .limit(limit) + ) + else: + # Get podcasts from all search spaces user has membership in + result = await session.execute( + select(Podcast) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + except HTTPException: + raise + except SQLAlchemyError: + raise HTTPException( + status_code=500, detail="Database error occurred while fetching podcasts" + ) from None + + +@router.get("/podcasts/{podcast_id}", response_model=PodcastRead) +async def read_podcast( + podcast_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific podcast by ID. + + Requires authentication with PODCASTS_READ permission. + For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream + """ + try: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + podcast = result.scalars().first() + + if not podcast: + raise HTTPException( + status_code=404, + detail="Podcast not found", + ) + + await check_permission( + session, + user, + podcast.search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to read podcasts in this search space", + ) + + return PodcastRead.from_orm_with_entries(podcast) + except HTTPException as he: + raise he + except SQLAlchemyError: + raise HTTPException( + status_code=500, detail="Database error occurred while fetching podcast" + ) from None + + +@router.delete("/podcasts/{podcast_id}", response_model=dict) +async def delete_podcast( + podcast_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete a podcast. + Requires PODCASTS_DELETE permission for the search space. + """ + try: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + db_podcast = result.scalars().first() + + if not db_podcast: + raise HTTPException(status_code=404, detail="Podcast not found") + + # Check permission for the search space + await check_permission( + session, + user, + db_podcast.search_space_id, + Permission.PODCASTS_DELETE.value, + "You don't have permission to delete podcasts in this search space", + ) + + await session.delete(db_podcast) + await session.commit() + return {"message": "Podcast deleted successfully"} + except HTTPException as he: + raise he + except SQLAlchemyError: + await session.rollback() + raise HTTPException( + status_code=500, detail="Database error occurred while deleting podcast" + ) from None + + +@router.get("/podcasts/{podcast_id}/stream") +@router.get("/podcasts/{podcast_id}/audio") +async def stream_podcast( + podcast_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Stream a podcast audio file. + + Requires authentication with PODCASTS_READ permission. + For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream + + Note: Both /stream and /audio endpoints are supported for compatibility. + """ + try: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + podcast = result.scalars().first() + + if not podcast: + raise HTTPException(status_code=404, detail="Podcast not found") + + await check_permission( + session, + user, + podcast.search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to access podcasts in this search space", + ) + + file_path = podcast.file_location + + if not file_path or not os.path.isfile(file_path): + raise HTTPException(status_code=404, detail="Podcast audio file not found") + + def iterfile(): + with open(file_path, mode="rb") as file_like: + yield from file_like + + return StreamingResponse( + iterfile(), + media_type="audio/mpeg", + headers={ + "Accept-Ranges": "bytes", + "Content-Disposition": f"inline; filename={Path(file_path).name}", + }, + ) + + except HTTPException as he: + raise he + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error streaming podcast: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/public_chat_routes.py b/surfsense_backend/app/routes/public_chat_routes.py index 53f4c2651..3181e117c 100644 --- a/surfsense_backend/app/routes/public_chat_routes.py +++ b/surfsense_backend/app/routes/public_chat_routes.py @@ -99,17 +99,6 @@ async def stream_public_podcast( if not podcast_info: raise HTTPException(status_code=404, detail="Podcast not found") - storage_key = podcast_info.get("storage_key") - if storage_key: - from app.file_storage.factory import get_storage_backend - - return StreamingResponse( - get_storage_backend().open_stream(storage_key), - media_type="audio/mpeg", - headers={"Accept-Ranges": "bytes"}, - ) - - # Legacy fallback for snapshots taken before the storage migration. file_path = podcast_info.get("file_path") if not file_path or not os.path.isfile(file_path): diff --git a/surfsense_backend/app/routes/stripe_routes.py b/surfsense_backend/app/routes/stripe_routes.py index 23dce58cd..fc5fded84 100644 --- a/surfsense_backend/app/routes/stripe_routes.py +++ b/surfsense_backend/app/routes/stripe_routes.py @@ -1,10 +1,4 @@ -"""Stripe routes for the unified credit wallet. - -Buying credit packs ($1 == 1_000_000 micro-USD by default) tops up -``user.credit_micros_balance``. The same balance is debited for ETL page -processing and premium model calls. Legacy page-pack buying has been removed; -``page_purchases`` history is still readable via ``GET /stripe/purchases``. -""" +"""Stripe routes for pay-as-you-go page purchases.""" from __future__ import annotations @@ -20,24 +14,24 @@ from stripe import SignatureVerificationError, StripeClient, StripeError from app.config import config from app.db import ( - CreditPurchase, - CreditPurchaseStatus, PagePurchase, + PagePurchaseStatus, + PremiumTokenPurchase, + PremiumTokenPurchaseStatus, User, get_async_session, ) from app.schemas.stripe import ( - AutoReloadSettingsResponse, - CreateAutoReloadSetupSessionRequest, - CreateAutoReloadSetupSessionResponse, - CreateCreditCheckoutSessionRequest, - CreateCreditCheckoutSessionResponse, - CreditPurchaseHistoryResponse, - CreditStripeStatusResponse, + CreateCheckoutSessionRequest, + CreateCheckoutSessionResponse, + CreateTokenCheckoutSessionRequest, + CreateTokenCheckoutSessionResponse, FinalizeCheckoutResponse, PagePurchaseHistoryResponse, + StripeStatusResponse, StripeWebhookResponse, - UpdateAutoReloadSettingsRequest, + TokenPurchaseHistoryResponse, + TokenStripeStatusResponse, ) from app.users import current_active_user @@ -56,11 +50,11 @@ def get_stripe_client() -> StripeClient: return StripeClient(config.STRIPE_SECRET_KEY) -def _ensure_credit_buying_enabled() -> None: - if not config.STRIPE_CREDIT_BUYING_ENABLED: +def _ensure_page_buying_enabled() -> None: + if not config.STRIPE_PAGE_BUYING_ENABLED: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Credit purchases are temporarily unavailable.", + detail="Page purchases are temporarily unavailable.", ) @@ -85,62 +79,13 @@ def _get_checkout_urls(search_space_id: int) -> tuple[str, str]: return success_url, cancel_url -def _get_required_credit_price_id() -> str: - if not config.STRIPE_CREDIT_PRICE_ID: +def _get_required_stripe_price_id() -> str: + if not config.STRIPE_PRICE_ID: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="STRIPE_CREDIT_PRICE_ID is not configured.", + detail="STRIPE_PRICE_ID is not configured.", ) - return config.STRIPE_CREDIT_PRICE_ID - - -def _ensure_auto_reload_enabled() -> None: - if not config.AUTO_RELOAD_ENABLED: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Auto-reload is not available.", - ) - - -async def _get_or_create_stripe_customer( - stripe_client: StripeClient, db_session: AsyncSession, user: User -) -> str: - """Return the user's Stripe Customer id, creating + persisting one if needed. - - A Customer object is required to save and later reuse a card off-session - (Stripe: save-and-reuse). New checkouts attach to this customer so the same - saved card powers both manual top-ups and auto-reload. - """ - if user.stripe_customer_id: - return user.stripe_customer_id - - customer = stripe_client.v1.customers.create( - params={ - "email": user.email, - "metadata": {"user_id": str(user.id)}, - } - ) - customer_id = str(customer.id) - - # Persist on the live row with a lock to avoid two concurrent checkouts - # creating duplicate customers. - locked = ( - ( - await db_session.execute( - select(User).where(User.id == user.id).with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if locked is not None: - if locked.stripe_customer_id: - # Another request won the race; reuse theirs. - customer_id = locked.stripe_customer_id - else: - locked.stripe_customer_id = customer_id - await db_session.commit() - return customer_id + return config.STRIPE_PRICE_ID def _normalize_optional_string(value: Any) -> str | None: @@ -165,9 +110,14 @@ def _get_metadata(checkout_session: Any) -> dict[str, str]: if metadata is None: return {} + # 1. Plain dict (older SDKs that subclassed dict, JSON-decoded events + # in tests, etc.). if isinstance(metadata, dict): return {str(k): str(v) for k, v in metadata.items()} + # 2. Modern Stripe SDK: every ``StripeObject`` has ``to_dict()``. + # ``recursive=False`` is correct because Stripe metadata values + # are always primitive strings. to_dict = getattr(metadata, "to_dict", None) if callable(to_dict): try: @@ -180,6 +130,8 @@ def _get_metadata(checkout_session: Any) -> dict[str, str]: getattr(checkout_session, "id", "?"), ) + # 3. Last-resort: read the SDK's private ``_data`` backing dict. + # Stable across stripe-python 6.x -> 15.x. inner = getattr(metadata, "_data", None) if isinstance(inner, dict): return {str(k): str(v) for k, v in inner.items()} @@ -192,90 +144,120 @@ def _get_metadata(checkout_session: Any) -> dict[str, str]: return {} -# Canonical purchase_type metadata value is ``credits``. ``premium_tokens`` and -# ``premium_credit`` were emitted by earlier releases so they're still accepted -# on the read side for any in-flight checkout sessions. -_PURCHASE_TYPE_CREDIT_VALUES = frozenset( - {"credits", "premium_tokens", "premium_credit"} -) +# Canonical purchase_type metadata values. ``premium_credit`` was emitted +# by an earlier release of ``create_token_checkout_session`` so it's still +# accepted on the read side for backward compat with in-flight sessions. +_PURCHASE_TYPE_TOKEN_VALUES = frozenset({"premium_tokens", "premium_credit"}) -def _is_credit_purchase(metadata: dict[str, str]) -> bool: - """Return True for a credit purchase (default for all live checkouts).""" - return metadata.get("purchase_type", "credits") in _PURCHASE_TYPE_CREDIT_VALUES +def _is_token_purchase(metadata: dict[str, str]) -> bool: + """Return True for premium-credit (a.k.a. premium_token) purchases.""" + return metadata.get("purchase_type", "page_packs") in _PURCHASE_TYPE_TOKEN_VALUES -async def _mark_credit_purchase_failed( +async def _get_or_create_purchase_from_checkout_session( + db_session: AsyncSession, + checkout_session: Any, +) -> PagePurchase | None: + """Look up a PagePurchase by checkout session ID (with FOR UPDATE lock). + + If the row doesn't exist yet (e.g. the webhook arrived before the API + response committed), create one from the Stripe session metadata. + """ + checkout_session_id = str(checkout_session.id) + purchase = ( + await db_session.execute( + select(PagePurchase) + .where(PagePurchase.stripe_checkout_session_id == checkout_session_id) + .with_for_update() + ) + ).scalar_one_or_none() + if purchase is not None: + return purchase + + metadata = _get_metadata(checkout_session) + user_id = metadata.get("user_id") + quantity = int(metadata.get("quantity", "0")) + pages_per_unit = int(metadata.get("pages_per_unit", "0")) + + if not user_id or quantity <= 0 or pages_per_unit <= 0: + logger.error( + "Skipping Stripe fulfillment for session %s due to incomplete metadata: %s", + checkout_session_id, + metadata, + ) + return None + + purchase = PagePurchase( + user_id=uuid.UUID(user_id), + stripe_checkout_session_id=checkout_session_id, + stripe_payment_intent_id=_normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ), + quantity=quantity, + pages_granted=quantity * pages_per_unit, + amount_total=getattr(checkout_session, "amount_total", None), + currency=getattr(checkout_session, "currency", None), + status=PagePurchaseStatus.PENDING, + ) + db_session.add(purchase) + await db_session.flush() + return purchase + + +async def _mark_purchase_failed( db_session: AsyncSession, checkout_session_id: str ) -> StripeWebhookResponse: purchase = ( await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.stripe_checkout_session_id == checkout_session_id) + select(PagePurchase) + .where(PagePurchase.stripe_checkout_session_id == checkout_session_id) .with_for_update() ) ).scalar_one_or_none() - if purchase is not None and purchase.status == CreditPurchaseStatus.PENDING: - purchase.status = CreditPurchaseStatus.FAILED + if purchase is not None and purchase.status == PagePurchaseStatus.PENDING: + purchase.status = PagePurchaseStatus.FAILED await db_session.commit() return StripeWebhookResponse() -async def _fulfill_completed_credit_purchase( - db_session: AsyncSession, checkout_session: Any +async def _mark_token_purchase_failed( + db_session: AsyncSession, checkout_session_id: str ) -> StripeWebhookResponse: - """Grant credit to the user after a confirmed Stripe payment. - - Uses ``SELECT ... FOR UPDATE`` on both the CreditPurchase and User rows to - prevent double-granting when Stripe retries the webhook concurrently. - """ - checkout_session_id = str(checkout_session.id) purchase = ( await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.stripe_checkout_session_id == checkout_session_id) + select(PremiumTokenPurchase) + .where( + PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id + ) .with_for_update() ) ).scalar_one_or_none() + if purchase is not None and purchase.status == PremiumTokenPurchaseStatus.PENDING: + purchase.status = PremiumTokenPurchaseStatus.FAILED + await db_session.commit() + + return StripeWebhookResponse() + + +async def _fulfill_completed_purchase( + db_session: AsyncSession, checkout_session: Any +) -> StripeWebhookResponse: + """Grant pages to the user after a confirmed Stripe payment. + + Uses SELECT ... FOR UPDATE on both the PagePurchase and User rows to + prevent double-granting when Stripe retries the webhook concurrently. + """ + purchase = await _get_or_create_purchase_from_checkout_session( + db_session, checkout_session + ) if purchase is None: - metadata = _get_metadata(checkout_session) - user_id = metadata.get("user_id") - quantity = int(metadata.get("quantity", "0")) - # Read the new metadata key first, fall back to legacy ones so - # in-flight checkout sessions created before the rename still fulfil. - credit_micros_per_unit = int( - metadata.get("credit_micros_per_unit") - or metadata.get("tokens_per_unit", "0") - ) + return StripeWebhookResponse() - if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: - logger.error( - "Skipping credit fulfillment for session %s: incomplete metadata %s", - checkout_session_id, - metadata, - ) - return StripeWebhookResponse() - - purchase = CreditPurchase( - user_id=uuid.UUID(user_id), - stripe_checkout_session_id=checkout_session_id, - stripe_payment_intent_id=_normalize_optional_string( - getattr(checkout_session, "payment_intent", None) - ), - quantity=quantity, - credit_micros_granted=quantity * credit_micros_per_unit, - amount_total=getattr(checkout_session, "amount_total", None), - currency=getattr(checkout_session, "currency", None), - source="checkout", - status=CreditPurchaseStatus.PENDING, - ) - db_session.add(purchase) - await db_session.flush() - - if purchase.status == CreditPurchaseStatus.COMPLETED: + if purchase.status == PagePurchaseStatus.COMPLETED: return StripeWebhookResponse() user = ( @@ -289,188 +271,132 @@ async def _fulfill_completed_credit_purchase( ) if user is None: logger.error( - "Skipping credit fulfillment for session %s: user %s not found", + "Skipping Stripe fulfillment for session %s because user %s was not found.", purchase.stripe_checkout_session_id, purchase.user_id, ) return StripeWebhookResponse() - purchase.status = CreditPurchaseStatus.COMPLETED + purchase.status = PagePurchaseStatus.COMPLETED purchase.completed_at = datetime.now(UTC) purchase.amount_total = getattr(checkout_session, "amount_total", None) purchase.currency = getattr(checkout_session, "currency", None) purchase.stripe_payment_intent_id = _normalize_optional_string( getattr(checkout_session, "payment_intent", None) ) - # Add the granted micro-USD directly to the spendable wallet balance. - user.credit_micros_balance = ( - user.credit_micros_balance + purchase.credit_micros_granted - ) + # pages_used can exceed pages_limit when a document's final page count is + # determined after processing. Base the new limit on the higher of the two + # so the purchased pages are fully usable above the current high-water mark. + user.pages_limit = max(user.pages_used, user.pages_limit) + purchase.pages_granted await db_session.commit() return StripeWebhookResponse() -async def _handle_setup_session_completed( - stripe_client: StripeClient, - db_session: AsyncSession, - checkout_session: Any, +async def _fulfill_completed_token_purchase( + db_session: AsyncSession, checkout_session: Any ) -> StripeWebhookResponse: - """Persist the saved card from a completed ``mode=setup`` checkout session. - - The setup session saves a card on the customer (Stripe save-and-reuse). We - pull the resulting payment method off the SetupIntent and store it as the - user's ``auto_reload_payment_method_id`` so the off-session charge can use - it. Auto-reload itself is only armed once the user enables it via the - settings endpoint. - """ - metadata = _get_metadata(checkout_session) - user_id = metadata.get("user_id") - if not user_id: - logger.warning( - "Setup session %s completed without user_id metadata", - getattr(checkout_session, "id", "?"), - ) - return StripeWebhookResponse() - - setup_intent_id = _normalize_optional_string( - getattr(checkout_session, "setup_intent", None) - ) - payment_method_id: str | None = None - if setup_intent_id: - try: - setup_intent = stripe_client.v1.setup_intents.retrieve(setup_intent_id) - payment_method_id = _normalize_optional_string( - getattr(setup_intent, "payment_method", None) + """Grant premium tokens to the user after a confirmed Stripe payment.""" + checkout_session_id = str(checkout_session.id) + purchase = ( + await db_session.execute( + select(PremiumTokenPurchase) + .where( + PremiumTokenPurchase.stripe_checkout_session_id == checkout_session_id ) - except StripeError: - logger.exception( - "Failed to retrieve setup intent %s for session %s", - setup_intent_id, - getattr(checkout_session, "id", "?"), - ) - - if not payment_method_id: - logger.warning( - "Setup session %s completed without a payment method", - getattr(checkout_session, "id", "?"), + .with_for_update() ) + ).scalar_one_or_none() + + if purchase is None: + metadata = _get_metadata(checkout_session) + user_id = metadata.get("user_id") + quantity = int(metadata.get("quantity", "0")) + # Read the new metadata key first, fall back to the legacy one so + # in-flight checkout sessions created before the cost-credits + # release still fulfil correctly (the unit is numerically the + # same: $1 buys 1_000_000 micro-USD == 1_000_000 tokens). + credit_micros_per_unit = int( + metadata.get("credit_micros_per_unit") + or metadata.get("tokens_per_unit", "0") + ) + + if not user_id or quantity <= 0 or credit_micros_per_unit <= 0: + logger.error( + "Skipping token fulfillment for session %s: incomplete metadata %s", + checkout_session_id, + metadata, + ) + return StripeWebhookResponse() + + purchase = PremiumTokenPurchase( + user_id=uuid.UUID(user_id), + stripe_checkout_session_id=checkout_session_id, + stripe_payment_intent_id=_normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ), + quantity=quantity, + credit_micros_granted=quantity * credit_micros_per_unit, + amount_total=getattr(checkout_session, "amount_total", None), + currency=getattr(checkout_session, "currency", None), + status=PremiumTokenPurchaseStatus.PENDING, + ) + db_session.add(purchase) + await db_session.flush() + + if purchase.status == PremiumTokenPurchaseStatus.COMPLETED: return StripeWebhookResponse() user = ( ( await db_session.execute( - select(User) - .where(User.id == uuid.UUID(user_id)) - .with_for_update(of=User) + select(User).where(User.id == purchase.user_id).with_for_update(of=User) ) ) .unique() .scalar_one_or_none() ) if user is None: + logger.error( + "Skipping token fulfillment for session %s: user %s not found", + purchase.stripe_checkout_session_id, + purchase.user_id, + ) return StripeWebhookResponse() - customer_id = _normalize_optional_string( - getattr(checkout_session, "customer", None) + purchase.status = PremiumTokenPurchaseStatus.COMPLETED + purchase.completed_at = datetime.now(UTC) + purchase.amount_total = getattr(checkout_session, "amount_total", None) + purchase.currency = getattr(checkout_session, "currency", None) + purchase.stripe_payment_intent_id = _normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ) + # Top up the user's credit balance by the granted micro-USD amount. + # ``max(used, limit)`` clamps the case where the legacy code wrote a + # used value above the limit (e.g. underbilling rounding) so adding + # ``credit_micros_granted`` always lifts the limit by the full pack + # size rather than disappearing into past overuse. + user.premium_credit_micros_limit = ( + max(user.premium_credit_micros_used, user.premium_credit_micros_limit) + + purchase.credit_micros_granted ) - if customer_id and not user.stripe_customer_id: - user.stripe_customer_id = customer_id - user.auto_reload_payment_method_id = payment_method_id - await db_session.commit() - - # Make this the customer's default for future off-session charges. - if user.stripe_customer_id: - try: - stripe_client.v1.customers.update( - user.stripe_customer_id, - params={ - "invoice_settings": {"default_payment_method": payment_method_id} - }, - ) - except StripeError: - logger.warning( - "Failed to set default payment method for customer %s", - user.stripe_customer_id, - exc_info=True, - ) - - return StripeWebhookResponse() - - -async def _reconcile_auto_reload_payment_intent( - db_session: AsyncSession, - payment_intent: Any, - *, - succeeded: bool, -) -> StripeWebhookResponse: - """Backstop for the off-session auto-reload charge via webhook. - - The Celery task confirms the PaymentIntent synchronously and grants credit - inline, but the ``payment_intent.succeeded`` / ``payment_intent.payment_failed`` - webhook acts as a safety net. We locate the matching ``auto_reload`` - CreditPurchase by payment-intent id and only transition PENDING rows so we - never double-grant. - """ - payment_intent_id = str(payment_intent.id) - purchase = ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.stripe_payment_intent_id == payment_intent_id) - .with_for_update() - ) - ).scalar_one_or_none() - - if purchase is None or purchase.status != CreditPurchaseStatus.PENDING: - return StripeWebhookResponse() - - if succeeded: - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == purchase.user_id) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if user is None: - return StripeWebhookResponse() - purchase.status = CreditPurchaseStatus.COMPLETED - purchase.completed_at = datetime.now(UTC) - user.credit_micros_balance = ( - user.credit_micros_balance + purchase.credit_micros_granted - ) - else: - purchase.status = CreditPurchaseStatus.FAILED await db_session.commit() return StripeWebhookResponse() -@router.post( - "/create-credit-checkout-session", - response_model=CreateCreditCheckoutSessionResponse, -) -async def create_credit_checkout_session( - body: CreateCreditCheckoutSessionRequest, +@router.post("/create-checkout-session", response_model=CreateCheckoutSessionResponse) +async def create_checkout_session( + body: CreateCheckoutSessionRequest, user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), -) -> CreateCreditCheckoutSessionResponse: - """Create a Stripe Checkout Session for buying credit packs. - - Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of credit - (default 1_000_000 = $1.00). The balance is debited at the actual provider - cost reported by LiteLLM (premium calls) or ``MICROS_PER_PAGE`` per page - (ETL), so $1 of credit always buys $1 worth of usage at cost. - """ - _ensure_credit_buying_enabled() +) -> CreateCheckoutSessionResponse: + """Create a Stripe Checkout Session for buying page packs.""" + _ensure_page_buying_enabled() stripe_client = get_stripe_client() - price_id = _get_required_credit_price_id() + price_id = _get_required_stripe_price_id() success_url, cancel_url = _get_checkout_urls(body.search_space_id) - credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT + pages_granted = body.quantity * config.STRIPE_PAGES_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( @@ -489,14 +415,14 @@ async def create_credit_checkout_session( "metadata": { "user_id": str(user.id), "quantity": str(body.quantity), - "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), - "purchase_type": "credits", + "pages_per_unit": str(config.STRIPE_PAGES_PER_UNIT), + "purchase_type": "page_packs", }, } ) except StripeError as exc: logger.exception( - "Failed to create credit checkout session for user %s", user.id + "Failed to create Stripe checkout session for user %s", user.id ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, @@ -511,23 +437,28 @@ async def create_credit_checkout_session( ) db_session.add( - CreditPurchase( + PagePurchase( user_id=user.id, stripe_checkout_session_id=str(checkout_session.id), stripe_payment_intent_id=_normalize_optional_string( getattr(checkout_session, "payment_intent", None) ), quantity=body.quantity, - credit_micros_granted=credit_micros_granted, + pages_granted=pages_granted, amount_total=getattr(checkout_session, "amount_total", None), currency=getattr(checkout_session, "currency", None), - source="checkout", - status=CreditPurchaseStatus.PENDING, + status=PagePurchaseStatus.PENDING, ) ) await db_session.commit() - return CreateCreditCheckoutSessionResponse(checkout_url=checkout_url) + return CreateCheckoutSessionResponse(checkout_url=checkout_url) + + +@router.get("/status", response_model=StripeStatusResponse) +async def get_stripe_status() -> StripeStatusResponse: + """Return page-buying availability for frontend feature gating.""" + return StripeStatusResponse(page_buying_enabled=config.STRIPE_PAGE_BUYING_ENABLED) @router.post("/webhook", response_model=StripeWebhookResponse) @@ -535,7 +466,7 @@ async def stripe_webhook( request: Request, db_session: AsyncSession = Depends(get_async_session), ) -> StripeWebhookResponse: - """Handle Stripe webhooks and grant purchased credit after payment.""" + """Handle Stripe webhooks and grant purchased pages after payment.""" if not config.STRIPE_WEBHOOK_SECRET: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, @@ -587,37 +518,12 @@ async def stripe_webhook( ) return StripeWebhookResponse() - # mode=setup sessions carry no line items / payment; they save a - # card for off-session auto-reload. - if getattr(checkout_session, "mode", None) == "setup": - return await _handle_setup_session_completed( - stripe_client, db_session, checkout_session - ) - metadata = _get_metadata(checkout_session) - if _is_credit_purchase(metadata): - return await _fulfill_completed_credit_purchase( + if _is_token_purchase(metadata): + return await _fulfill_completed_token_purchase( db_session, checkout_session ) - # Legacy page-pack purchase: page buying is removed, so log and - # ignore rather than fulfilling. - logger.info( - "Ignoring non-credit checkout session %s (purchase_type=%s); " - "page buying is removed.", - getattr(checkout_session, "id", "?"), - metadata.get("purchase_type"), - ) - return StripeWebhookResponse() - - if event.type == "payment_intent.succeeded": - return await _reconcile_auto_reload_payment_intent( - db_session, event.data.object, succeeded=True - ) - - if event.type == "payment_intent.payment_failed": - return await _reconcile_auto_reload_payment_intent( - db_session, event.data.object, succeeded=False - ) + return await _fulfill_completed_purchase(db_session, checkout_session) if event.type in { "checkout.session.async_payment_failed", @@ -625,12 +531,16 @@ async def stripe_webhook( }: checkout_session = event.data.object metadata = _get_metadata(checkout_session) - if _is_credit_purchase(metadata): - return await _mark_credit_purchase_failed( + if _is_token_purchase(metadata): + return await _mark_token_purchase_failed( db_session, str(checkout_session.id) ) - return StripeWebhookResponse() + return await _mark_purchase_failed(db_session, str(checkout_session.id)) except Exception: + # Re-raise so FastAPI returns 500 and Stripe retries this delivery. + # Logging here gives us a structured trail with event id + type so + # future webhook bugs surface immediately in the logs without + # having to grep by request_id. logger.exception( "Stripe webhook handler failed for event id=%s type=%s — Stripe will retry", getattr(event, "id", "?"), @@ -647,17 +557,24 @@ async def finalize_checkout( user: User = Depends(current_active_user), db_session: AsyncSession = Depends(get_async_session), ) -> FinalizeCheckoutResponse: - """Synchronously fulfil a credit checkout session from the success page. + """Synchronously fulfil a checkout session from the success page. Solves the webhook-vs-redirect race: the user lands on ``/dashboard/<id>/purchase-success?session_id=cs_...`` typically a - few hundred ms after paying, but Stripe's ``checkout.session.completed`` - webhook can take 5-30s+ to arrive. Calling this endpoint on success-page - mount fulfils the purchase immediately via the same idempotent helper the - webhook uses. + few hundred ms after paying, but Stripe's + ``checkout.session.completed`` webhook can take 5-30s+ to arrive. + Calling this endpoint on success-page mount fulfils the purchase + immediately by retrieving the session from Stripe's API and + invoking the same idempotent helpers the webhook uses. + + Idempotency: if the webhook has already fulfilled this purchase + (status=COMPLETED), the helpers short-circuit and we just return + the latest balance. Concurrent webhook + finalize calls are safe + because both acquire ``SELECT ... FOR UPDATE`` on the purchase row. Authorization: the session's ``client_reference_id`` must match the - authenticated user's id. + authenticated user's id. This prevents a user from finalising + someone else's checkout session if they happen to know the id. """ stripe_client = get_stripe_client() @@ -675,6 +592,9 @@ async def finalize_checkout( detail="Checkout session not found.", ) from exc + # Authorization check: the user finalising must be the user who + # initiated the checkout. ``client_reference_id`` is set in + # ``create_checkout_session`` / ``create_token_checkout_session``. client_reference_id = getattr(checkout_session, "client_reference_id", None) if client_reference_id != str(user.id): logger.warning( @@ -688,75 +608,109 @@ async def finalize_checkout( detail="This checkout session does not belong to you.", ) + metadata = _get_metadata(checkout_session) + is_token = _is_token_purchase(metadata) payment_status = getattr(checkout_session, "payment_status", None) session_status = getattr(checkout_session, "status", None) + + # Defensive fallback: if metadata can't be read for any reason + # (extraction failure, manually-created session in Stripe dashboard, + # SDK upgrade breaking ``to_dict``, etc.) we'd otherwise route every + # purchase to the page_packs handler and get stuck. Resolve the + # purchase_type by checking which table actually has the row keyed + # by this Stripe session id. + if not metadata: + existing_token_purchase = ( + await db_session.execute( + select(PremiumTokenPurchase.id).where( + PremiumTokenPurchase.stripe_checkout_session_id + == str(checkout_session.id) + ) + ) + ).scalar_one_or_none() + if existing_token_purchase is not None: + is_token = True + else: + existing_page_purchase = ( + await db_session.execute( + select(PagePurchase.id).where( + PagePurchase.stripe_checkout_session_id + == str(checkout_session.id) + ) + ) + ).scalar_one_or_none() + if existing_page_purchase is None: + logger.error( + "finalize_checkout: no purchase row in either table " + "and metadata is empty for session=%s user=%s", + session_id, + user.id, + ) + # Fall through; downstream path will short-circuit on + # missing-row + empty-metadata. + logger.info( + "finalize_checkout: recovered purchase_type=%s for session=%s " + "via DB fallback (metadata was empty)", + "premium_tokens" if is_token else "page_packs", + session_id, + ) + is_paid = payment_status in {"paid", "no_payment_required"} is_expired = session_status == "expired" if is_paid: - await _fulfill_completed_credit_purchase(db_session, checkout_session) + if is_token: + await _fulfill_completed_token_purchase(db_session, checkout_session) + else: + await _fulfill_completed_purchase(db_session, checkout_session) elif is_expired: - await _mark_credit_purchase_failed(db_session, str(checkout_session.id)) - # Otherwise leave the row alone — frontend keeps polling and the webhook - # will eventually win the race. + if is_token: + await _mark_token_purchase_failed(db_session, str(checkout_session.id)) + else: + await _mark_purchase_failed(db_session, str(checkout_session.id)) + # Otherwise (e.g. payment_status="unpaid", session_status="open"), + # leave the purchase row alone — frontend will keep polling and the + # webhook will eventually win the race. + # Refresh the user row so the response reflects any update applied + # by the fulfilment helpers in this same session. await db_session.refresh(user) + if is_token: + purchase = ( + await db_session.execute( + select(PremiumTokenPurchase).where( + PremiumTokenPurchase.stripe_checkout_session_id + == str(checkout_session.id) + ) + ) + ).scalar_one_or_none() + return FinalizeCheckoutResponse( + purchase_type="premium_tokens", + status=purchase.status.value if purchase else "pending", + premium_credit_micros_limit=user.premium_credit_micros_limit, + premium_credit_micros_used=user.premium_credit_micros_used, + premium_credit_micros_granted=( + purchase.credit_micros_granted if purchase else None + ), + ) + purchase = ( await db_session.execute( - select(CreditPurchase).where( - CreditPurchase.stripe_checkout_session_id == str(checkout_session.id) + select(PagePurchase).where( + PagePurchase.stripe_checkout_session_id == str(checkout_session.id) ) ) ).scalar_one_or_none() return FinalizeCheckoutResponse( + purchase_type="page_packs", status=purchase.status.value if purchase else "pending", - credit_micros_balance=user.credit_micros_balance, - credit_micros_granted=(purchase.credit_micros_granted if purchase else None), + pages_limit=user.pages_limit, + pages_used=user.pages_used, + pages_granted=purchase.pages_granted if purchase else None, ) -@router.get("/credit-status", response_model=CreditStripeStatusResponse) -async def get_credit_status( - user: User = Depends(current_active_user), -) -> CreditStripeStatusResponse: - """Return credit-buying availability and current balance for the frontend. - - ``credit_micros_balance`` is in micro-USD (1_000_000 = $1.00); the FE - divides by 1M when displaying. - """ - return CreditStripeStatusResponse( - credit_buying_enabled=config.STRIPE_CREDIT_BUYING_ENABLED, - credit_micros_balance=user.credit_micros_balance, - ) - - -@router.get("/credit-purchases", response_model=CreditPurchaseHistoryResponse) -async def get_credit_purchases( - user: User = Depends(current_active_user), - db_session: AsyncSession = Depends(get_async_session), - offset: int = 0, - limit: int = 50, -) -> CreditPurchaseHistoryResponse: - """Return the authenticated user's credit purchase history.""" - limit = min(limit, 100) - purchases = ( - ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.user_id == user.id) - .order_by(CreditPurchase.created_at.desc()) - .offset(offset) - .limit(limit) - ) - ) - .scalars() - .all() - ) - - return CreditPurchaseHistoryResponse(purchases=purchases) - - @router.get("/purchases", response_model=PagePurchaseHistoryResponse) async def get_page_purchases( user: User = Depends(current_active_user), @@ -764,10 +718,7 @@ async def get_page_purchases( offset: int = 0, limit: int = 50, ) -> PagePurchaseHistoryResponse: - """Return the authenticated user's legacy page-purchase history (read-only). - - Page buying is removed; this endpoint stays for historical records. - """ + """Return the authenticated user's page-purchase history.""" limit = min(limit, 100) purchases = ( ( @@ -786,155 +737,163 @@ async def get_page_purchases( return PagePurchaseHistoryResponse(purchases=purchases) -def _auto_reload_settings_response(user: User) -> AutoReloadSettingsResponse: - return AutoReloadSettingsResponse( - feature_enabled=config.AUTO_RELOAD_ENABLED, - enabled=bool(user.auto_reload_enabled), - threshold_micros=user.auto_reload_threshold_micros, - amount_micros=user.auto_reload_amount_micros, - min_amount_micros=config.AUTO_RELOAD_MIN_AMOUNT_MICROS, - has_payment_method=bool(user.auto_reload_payment_method_id), - failed_at=user.auto_reload_failed_at, - ) +# ============================================================================= +# Premium Token Purchase Routes +# ============================================================================= -@router.post( - "/auto-reload/setup", - response_model=CreateAutoReloadSetupSessionResponse, -) -async def create_auto_reload_setup_session( - body: CreateAutoReloadSetupSessionRequest, - user: User = Depends(current_active_user), - db_session: AsyncSession = Depends(get_async_session), -) -> CreateAutoReloadSetupSessionResponse: - """Start a ``mode=setup`` checkout session to save a card for auto-reload. +def _ensure_token_buying_enabled() -> None: + if not config.STRIPE_TOKEN_BUYING_ENABLED: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Premium token purchases are temporarily unavailable.", + ) - Uses a SetupIntent (no immediate charge) attached to the user's Stripe - Customer so the card can later be charged off-session. On completion the - webhook stores the resulting payment method on the user. - """ - _ensure_auto_reload_enabled() - _ensure_credit_buying_enabled() - stripe_client = get_stripe_client() + +def _get_token_checkout_urls(search_space_id: int) -> tuple[str, str]: if not config.NEXT_FRONTEND_URL: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="NEXT_FRONTEND_URL is not configured.", ) - customer_id = await _get_or_create_stripe_customer(stripe_client, db_session, user) - base_url = config.NEXT_FRONTEND_URL.rstrip("/") + # See ``_get_checkout_urls`` for why session_id is appended. success_url = ( - f"{base_url}/dashboard/{body.search_space_id}/user-settings/purchases" - f"?auto_reload_setup=success" - ) - cancel_url = ( - f"{base_url}/dashboard/{body.search_space_id}/user-settings/purchases" - f"?auto_reload_setup=cancel" + f"{base_url}/dashboard/{search_space_id}/purchase-success" + f"?session_id={{CHECKOUT_SESSION_ID}}" ) + cancel_url = f"{base_url}/dashboard/{search_space_id}/purchase-cancel" + return success_url, cancel_url + + +def _get_required_token_price_id() -> str: + if not config.STRIPE_PREMIUM_TOKEN_PRICE_ID: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="STRIPE_PREMIUM_TOKEN_PRICE_ID is not configured.", + ) + return config.STRIPE_PREMIUM_TOKEN_PRICE_ID + + +@router.post("/create-token-checkout-session") +async def create_token_checkout_session( + body: CreateTokenCheckoutSessionRequest, + user: User = Depends(current_active_user), + db_session: AsyncSession = Depends(get_async_session), +): + """Create a Stripe Checkout Session for buying premium credit packs. + + Each pack grants ``STRIPE_CREDIT_MICROS_PER_UNIT`` micro-USD of + credit (default 1_000_000 = $1.00). The user's balance is debited + at the actual provider cost reported by LiteLLM at finalize time, + so $1 of credit always buys $1 worth of provider usage at cost. + """ + _ensure_token_buying_enabled() + stripe_client = get_stripe_client() + price_id = _get_required_token_price_id() + success_url, cancel_url = _get_token_checkout_urls(body.search_space_id) + credit_micros_granted = body.quantity * config.STRIPE_CREDIT_MICROS_PER_UNIT try: checkout_session = stripe_client.v1.checkout.sessions.create( params={ - "mode": "setup", - # Required in setup mode when payment_method_types is omitted - # (dynamic payment methods); auto-reload charges are in USD. - "currency": "usd", + "mode": "payment", "success_url": success_url, "cancel_url": cancel_url, - "customer": customer_id, + "line_items": [ + { + "price": price_id, + "quantity": body.quantity, + } + ], "client_reference_id": str(user.id), + "customer_email": user.email, "metadata": { "user_id": str(user.id), - "purchase_type": "auto_reload_setup", + "quantity": str(body.quantity), + "credit_micros_per_unit": str(config.STRIPE_CREDIT_MICROS_PER_UNIT), + # Canonical value matched by ``_is_token_purchase``. + # The legacy ``"premium_credit"`` is still accepted on + # the read side for any in-flight sessions started + # before this rename. + "purchase_type": "premium_tokens", }, } ) except StripeError as exc: - logger.exception( - "Failed to create auto-reload setup session for user %s", user.id - ) + logger.exception("Failed to create token checkout session for user %s", user.id) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, - detail="Unable to create Stripe setup session.", + detail="Unable to create Stripe checkout session.", ) from exc checkout_url = getattr(checkout_session, "url", None) if not checkout_url: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, - detail="Stripe setup session did not return a URL.", + detail="Stripe checkout session did not return a URL.", ) - return CreateAutoReloadSetupSessionResponse(checkout_url=checkout_url) + db_session.add( + PremiumTokenPurchase( + user_id=user.id, + stripe_checkout_session_id=str(checkout_session.id), + stripe_payment_intent_id=_normalize_optional_string( + getattr(checkout_session, "payment_intent", None) + ), + quantity=body.quantity, + credit_micros_granted=credit_micros_granted, + amount_total=getattr(checkout_session, "amount_total", None), + currency=getattr(checkout_session, "currency", None), + status=PremiumTokenPurchaseStatus.PENDING, + ) + ) + await db_session.commit() + + return CreateTokenCheckoutSessionResponse(checkout_url=checkout_url) -@router.get("/auto-reload", response_model=AutoReloadSettingsResponse) -async def get_auto_reload_settings( +@router.get("/token-status") +async def get_token_status( user: User = Depends(current_active_user), -) -> AutoReloadSettingsResponse: - """Return the user's auto-reload configuration and saved-card state.""" - return _auto_reload_settings_response(user) +): + """Return token-buying availability and current premium credit quota for frontend. - -@router.put("/auto-reload", response_model=AutoReloadSettingsResponse) -async def update_auto_reload_settings( - body: UpdateAutoReloadSettingsRequest, - user: User = Depends(current_active_user), - db_session: AsyncSession = Depends(get_async_session), -) -> AutoReloadSettingsResponse: - """Update auto-reload preferences. - - Enabling requires a saved card plus a positive threshold and an amount of - at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. Disabling always succeeds and - clears any prior failure flag. + Values are in micro-USD (1_000_000 = $1.00); the FE divides by 1M + when displaying. The route name is preserved for back-compat with + pinned client deployments. """ - _ensure_auto_reload_enabled() - - locked = ( - ( - await db_session.execute( - select(User).where(User.id == user.id).with_for_update(of=User) - ) - ) - .unique() - .scalar_one() + used = user.premium_credit_micros_used + limit = user.premium_credit_micros_limit + return TokenStripeStatusResponse( + token_buying_enabled=config.STRIPE_TOKEN_BUYING_ENABLED, + premium_credit_micros_used=used, + premium_credit_micros_limit=limit, + premium_credit_micros_remaining=max(0, limit - used), ) - if body.enabled: - if not locked.auto_reload_payment_method_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Add a payment method before enabling auto-reload.", - ) - if not body.threshold_micros or body.threshold_micros <= 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="A positive low-balance threshold is required.", - ) - if ( - body.amount_micros is None - or body.amount_micros < config.AUTO_RELOAD_MIN_AMOUNT_MICROS - ): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=( - "Reload amount must be at least " - f"{config.AUTO_RELOAD_MIN_AMOUNT_MICROS} micro-USD." - ), - ) - locked.auto_reload_enabled = True - locked.auto_reload_threshold_micros = body.threshold_micros - locked.auto_reload_amount_micros = body.amount_micros - # Re-enabling clears the prior failure flag so the user can retry. - locked.auto_reload_failed_at = None - else: - locked.auto_reload_enabled = False - if body.threshold_micros is not None: - locked.auto_reload_threshold_micros = body.threshold_micros - if body.amount_micros is not None: - locked.auto_reload_amount_micros = body.amount_micros - await db_session.commit() - await db_session.refresh(locked) - return _auto_reload_settings_response(locked) +@router.get("/token-purchases") +async def get_token_purchases( + user: User = Depends(current_active_user), + db_session: AsyncSession = Depends(get_async_session), + offset: int = 0, + limit: int = 50, +): + """Return the authenticated user's premium token purchase history.""" + limit = min(limit, 100) + purchases = ( + ( + await db_session.execute( + select(PremiumTokenPurchase) + .where(PremiumTokenPurchase.user_id == user.id) + .order_by(PremiumTokenPurchase.created_at.desc()) + .offset(offset) + .limit(limit) + ) + ) + .scalars() + .all() + ) + + return TokenPurchaseHistoryResponse(purchases=purchases) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 212a6aa44..fdf34672b 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -68,6 +68,7 @@ from .new_llm_config import ( NewLLMConfigRead, NewLLMConfigUpdate, ) +from .podcasts import PodcastBase, PodcastCreate, PodcastRead, PodcastUpdate from .rbac_schemas import ( InviteAcceptRequest, InviteAcceptResponse, @@ -110,13 +111,11 @@ from .search_space import ( SearchSpaceWithStats, ) from .stripe import ( - CreateCreditCheckoutSessionRequest, - CreateCreditCheckoutSessionResponse, - CreditPurchaseHistoryResponse, - CreditPurchaseRead, - CreditStripeStatusResponse, + CreateCheckoutSessionRequest, + CreateCheckoutSessionResponse, PagePurchaseHistoryResponse, PagePurchaseRead, + StripeStatusResponse, StripeWebhookResponse, ) from .users import UserCreate, UserRead, UserUpdate @@ -144,11 +143,8 @@ __all__ = [ "ChunkCreate", "ChunkRead", "ChunkUpdate", - "CreateCreditCheckoutSessionRequest", - "CreateCreditCheckoutSessionResponse", - "CreditPurchaseHistoryResponse", - "CreditPurchaseRead", - "CreditStripeStatusResponse", + "CreateCheckoutSessionRequest", + "CreateCheckoutSessionResponse", "DefaultSystemInstructionsResponse", # Document schemas "DocumentBase", @@ -236,6 +232,10 @@ __all__ = [ "PermissionInfo", "PermissionsListResponse", # Podcast schemas + "PodcastBase", + "PodcastCreate", + "PodcastRead", + "PodcastUpdate", "RefreshTokenRequest", "RefreshTokenResponse", # Report schemas @@ -257,6 +257,7 @@ __all__ = [ "SearchSpaceRead", "SearchSpaceUpdate", "SearchSpaceWithStats", + "StripeStatusResponse", "StripeWebhookResponse", "ThreadHistoryLoadResponse", "ThreadListItem", diff --git a/surfsense_backend/app/schemas/incentive_tasks.py b/surfsense_backend/app/schemas/incentive_tasks.py index 7b9b39cd1..52c2a5182 100644 --- a/surfsense_backend/app/schemas/incentive_tasks.py +++ b/surfsense_backend/app/schemas/incentive_tasks.py @@ -15,8 +15,7 @@ class IncentiveTaskInfo(BaseModel): task_type: IncentiveTaskType title: str description: str - # Credit reward in USD micro-units (1_000_000 == $1.00). - credit_micros_reward: int + pages_reward: int action_url: str completed: bool completed_at: datetime | None = None @@ -26,7 +25,7 @@ class IncentiveTasksResponse(BaseModel): """Response containing all available incentive tasks with completion status.""" tasks: list[IncentiveTaskInfo] - total_credit_micros_earned: int + total_pages_earned: int class CompleteTaskRequest(BaseModel): @@ -40,8 +39,8 @@ class CompleteTaskResponse(BaseModel): success: bool message: str - credit_micros_awarded: int - new_balance_micros: int + pages_awarded: int + new_pages_limit: int class TaskAlreadyCompletedResponse(BaseModel): diff --git a/surfsense_backend/app/schemas/podcasts.py b/surfsense_backend/app/schemas/podcasts.py new file mode 100644 index 000000000..d41f1ca36 --- /dev/null +++ b/surfsense_backend/app/schemas/podcasts.py @@ -0,0 +1,66 @@ +"""Podcast schemas for API responses.""" + +from datetime import datetime +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel + + +class PodcastStatusEnum(StrEnum): + PENDING = "pending" + GENERATING = "generating" + READY = "ready" + FAILED = "failed" + + +class PodcastBase(BaseModel): + """Base podcast schema.""" + + title: str + podcast_transcript: list[dict[str, Any]] | None = None + file_location: str | None = None + search_space_id: int + + +class PodcastCreate(PodcastBase): + """Schema for creating a podcast.""" + + pass + + +class PodcastUpdate(BaseModel): + """Schema for updating a podcast.""" + + title: str | None = None + podcast_transcript: list[dict[str, Any]] | None = None + file_location: str | None = None + + +class PodcastRead(PodcastBase): + """Schema for reading a podcast.""" + + id: int + status: PodcastStatusEnum = PodcastStatusEnum.READY + created_at: datetime + transcript_entries: int | None = None + + class Config: + from_attributes = True + + @classmethod + def from_orm_with_entries(cls, obj): + """Create PodcastRead with transcript_entries computed.""" + data = { + "id": obj.id, + "title": obj.title, + "podcast_transcript": obj.podcast_transcript, + "file_location": obj.file_location, + "search_space_id": obj.search_space_id, + "status": obj.status, + "created_at": obj.created_at, + "transcript_entries": len(obj.podcast_transcript) + if obj.podcast_transcript + else None, + } + return cls(**data) diff --git a/surfsense_backend/app/schemas/stripe.py b/surfsense_backend/app/schemas/stripe.py index 95c946a3d..ad13ddf04 100644 --- a/surfsense_backend/app/schemas/stripe.py +++ b/surfsense_backend/app/schemas/stripe.py @@ -1,4 +1,4 @@ -"""Schemas for Stripe-backed credit purchases.""" +"""Schemas for Stripe-backed page purchases.""" import uuid from datetime import datetime @@ -8,59 +8,27 @@ from pydantic import BaseModel, ConfigDict, Field from app.db import PagePurchaseStatus -class CreateCreditCheckoutSessionRequest(BaseModel): - """Request body for creating a credit-purchase checkout session.""" +class CreateCheckoutSessionRequest(BaseModel): + """Request body for creating a page-purchase checkout session.""" - quantity: int = Field(ge=1, le=10_000) + quantity: int = Field(ge=1, le=100) search_space_id: int = Field(ge=1) -class CreateCreditCheckoutSessionResponse(BaseModel): +class CreateCheckoutSessionResponse(BaseModel): """Response containing the Stripe-hosted checkout URL.""" checkout_url: str -class CreditPurchaseRead(BaseModel): - """Serialized credit purchase record. +class StripeStatusResponse(BaseModel): + """Response describing Stripe page-buying availability.""" - ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). - """ - - id: uuid.UUID - stripe_checkout_session_id: str - stripe_payment_intent_id: str | None = None - quantity: int - credit_micros_granted: int - amount_total: int | None = None - currency: str | None = None - source: str = "checkout" - status: str - completed_at: datetime | None = None - created_at: datetime - - model_config = ConfigDict(from_attributes=True) - - -class CreditPurchaseHistoryResponse(BaseModel): - """Response containing the user's credit purchases.""" - - purchases: list[CreditPurchaseRead] - - -class CreditStripeStatusResponse(BaseModel): - """Response describing credit-buying availability and current balance. - - ``credit_micros_balance`` is in micro-USD; the FE divides by 1_000_000 - to display USD. - """ - - credit_buying_enabled: bool - credit_micros_balance: int = 0 + page_buying_enabled: bool class PagePurchaseRead(BaseModel): - """Serialized legacy page-purchase record (read-only history).""" + """Serialized page-purchase record for purchase history.""" id: uuid.UUID stripe_checkout_session_id: str @@ -77,52 +45,11 @@ class PagePurchaseRead(BaseModel): class PagePurchaseHistoryResponse(BaseModel): - """Response containing the authenticated user's legacy page purchases.""" + """Response containing the authenticated user's page purchases.""" purchases: list[PagePurchaseRead] -class AutoReloadSettingsResponse(BaseModel): - """Auto-reload configuration + saved-card state for the settings UI. - - All ``*_micros`` fields are micro-USD (1_000_000 == $1.00). ``feature_enabled`` - reflects the server-side ``AUTO_RELOAD_ENABLED`` flag; when it is false the - UI should hide / disable the auto-reload controls entirely. - """ - - feature_enabled: bool - enabled: bool = False - threshold_micros: int | None = None - amount_micros: int | None = None - min_amount_micros: int - has_payment_method: bool = False - failed_at: datetime | None = None - - -class UpdateAutoReloadSettingsRequest(BaseModel): - """Update auto-reload preferences. - - Enabling requires a saved card (set up via /stripe/auto-reload/setup) plus a - positive threshold and an amount of at least ``AUTO_RELOAD_MIN_AMOUNT_MICROS``. - """ - - enabled: bool - threshold_micros: int | None = Field(default=None, ge=0) - amount_micros: int | None = Field(default=None, ge=0) - - -class CreateAutoReloadSetupSessionRequest(BaseModel): - """Request body for starting the save-a-card (SetupIntent) checkout.""" - - search_space_id: int = Field(ge=1) - - -class CreateAutoReloadSetupSessionResponse(BaseModel): - """Response containing the Stripe-hosted setup (save-card) checkout URL.""" - - checkout_url: str - - class StripeWebhookResponse(BaseModel): """Generic acknowledgement for Stripe webhook delivery.""" @@ -139,6 +66,64 @@ class FinalizeCheckoutResponse(BaseModel): endpoint until it sees ``completed`` or a final ``failed``. """ + purchase_type: str # "page_packs" | "premium_tokens" + status: str # PagePurchaseStatus / PremiumTokenPurchaseStatus value + pages_limit: int | None = None + pages_used: int | None = None + pages_granted: int | None = None + premium_credit_micros_limit: int | None = None + premium_credit_micros_used: int | None = None + premium_credit_micros_granted: int | None = None + + +class CreateTokenCheckoutSessionRequest(BaseModel): + """Request body for creating a premium token purchase checkout session.""" + + quantity: int = Field(ge=1, le=100) + search_space_id: int = Field(ge=1) + + +class CreateTokenCheckoutSessionResponse(BaseModel): + """Response containing the Stripe-hosted checkout URL.""" + + checkout_url: str + + +class TokenPurchaseRead(BaseModel): + """Serialized premium credit purchase record. + + ``credit_micros_granted`` is in micro-USD (1_000_000 = $1.00). The + schema name kept ``Token`` for API back-compat with pinned clients. + """ + + id: uuid.UUID + stripe_checkout_session_id: str + stripe_payment_intent_id: str | None = None + quantity: int + credit_micros_granted: int + amount_total: int | None = None + currency: str | None = None status: str - credit_micros_balance: int = 0 - credit_micros_granted: int | None = None + completed_at: datetime | None = None + created_at: datetime + + model_config = ConfigDict(from_attributes=True) + + +class TokenPurchaseHistoryResponse(BaseModel): + """Response containing the user's premium credit purchases.""" + + purchases: list[TokenPurchaseRead] + + +class TokenStripeStatusResponse(BaseModel): + """Response describing premium-credit-buying availability and balance. + + All ``premium_credit_micros_*`` fields are in micro-USD; the FE + divides by 1_000_000 to display USD. + """ + + token_buying_enabled: bool + premium_credit_micros_used: int = 0 + premium_credit_micros_limit: int = 0 + premium_credit_micros_remaining: int = 0 diff --git a/surfsense_backend/app/schemas/users.py b/surfsense_backend/app/schemas/users.py index 558463f57..88d0a4f37 100644 --- a/surfsense_backend/app/schemas/users.py +++ b/surfsense_backend/app/schemas/users.py @@ -4,7 +4,8 @@ from fastapi_users import schemas class UserRead(schemas.BaseUser[uuid.UUID]): - credit_micros_balance: int + pages_limit: int + pages_used: int display_name: str | None = None avatar_url: str | None = None diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index c9fd8c315..9bbca8669 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -268,7 +268,7 @@ async def _is_premium_eligible( parsed = _to_uuid(user_id) if parsed is None: return False - usage = await TokenQuotaService.credit_get_usage(session, parsed) + usage = await TokenQuotaService.premium_get_usage(session, parsed) return bool(usage.allowed) diff --git a/surfsense_backend/app/services/auto_reload_service.py b/surfsense_backend/app/services/auto_reload_service.py deleted file mode 100644 index 9f5114a56..000000000 --- a/surfsense_backend/app/services/auto_reload_service.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Debit-triggered credit auto-reload. - -``maybe_trigger_auto_reload`` is a cheap, best-effort pre-filter invoked after -every credit debit (ETL ``charge_credits`` and premium ``credit_finalize``). -When the wallet drops below the user's configured threshold it enqueues the -Celery task that performs the authoritative re-check and the off-session Stripe -charge. All real safety (row lock, cooldown, Stripe idempotency) lives in the -task — this function only avoids enqueuing work that obviously isn't needed. - -Everything here is gated behind ``config.AUTO_RELOAD_ENABLED``; when the flag is -off this module is inert. -""" - -from __future__ import annotations - -import logging -from datetime import UTC, datetime, timedelta - -from sqlalchemy import select - -from app.config import config - -logger = logging.getLogger(__name__) - - -async def maybe_trigger_auto_reload(user_id: str) -> None: - """Enqueue an auto-reload charge if the user's balance fell below threshold. - - Best-effort: any failure is swallowed by the caller. Opens its own - short-lived session so it never interferes with the caller's transaction - (it always runs after the caller has already committed the debit). - """ - if not config.AUTO_RELOAD_ENABLED: - return - - from app.db import CreditPurchase, CreditPurchaseStatus, User, async_session_maker - - async with async_session_maker() as session: - user = ( - (await session.execute(select(User).where(User.id == user_id))) - .unique() - .scalar_one_or_none() - ) - if user is None or not user.auto_reload_enabled: - return - - if not (user.stripe_customer_id and user.auto_reload_payment_method_id): - return - - threshold = user.auto_reload_threshold_micros - amount = user.auto_reload_amount_micros - if not threshold or not amount: - return - - available = user.credit_micros_balance - user.credit_micros_reserved - if available >= threshold: - return - - # Cheap cooldown pre-check: skip if a recent auto-reload purchase exists - # or a recent attempt failed (avoids hammering a declined card). - cutoff = datetime.now(UTC) - timedelta( - minutes=max(config.AUTO_RELOAD_COOLDOWN_MINUTES, 0) - ) - if user.auto_reload_failed_at and user.auto_reload_failed_at >= cutoff: - return - recent = ( - await session.execute( - select(CreditPurchase.id) - .where( - CreditPurchase.user_id == user.id, - CreditPurchase.source == "auto_reload", - CreditPurchase.created_at >= cutoff, - CreditPurchase.status.in_( - [ - CreditPurchaseStatus.PENDING, - CreditPurchaseStatus.COMPLETED, - ] - ), - ) - .limit(1) - ) - ).first() - if recent is not None: - return - - # Enqueue outside the session. The task re-checks everything with a row - # lock before charging, so a benign race here only costs a no-op task run. - try: - from app.tasks.celery_tasks.auto_reload_task import ( - auto_reload_credits_task, - ) - - auto_reload_credits_task.delay(str(user_id)) - except Exception: - logger.warning( - "Failed to enqueue auto_reload_credits task for user %s", - user_id, - exc_info=True, - ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index 919c49a21..92ccd6a78 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -69,8 +69,8 @@ BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]] class QuotaInsufficientError(Exception): - """Raised when ``TokenQuotaService.credit_reserve`` denies a billable - call because the user has exhausted their credit wallet. + """Raised when ``TokenQuotaService.premium_reserve`` denies a billable + call because the user has exhausted their premium credit pool. The route handler should catch this and return HTTP 402 Payment Required (or the equivalent for the surface area). Outside of the HTTP @@ -83,15 +83,17 @@ class QuotaInsufficientError(Exception): self, *, usage_type: str, - balance_micros: int, + used_micros: int, + limit_micros: int, remaining_micros: int, ) -> None: self.usage_type = usage_type - self.balance_micros = balance_micros + self.used_micros = used_micros + self.limit_micros = limit_micros self.remaining_micros = remaining_micros super().__init__( - f"Credit exhausted for {usage_type}: " - f"balance={balance_micros} remaining={remaining_micros} (micro-USD)" + f"Premium credit exhausted for {usage_type}: " + f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)" ) @@ -265,7 +267,7 @@ async def billable_call( ``TokenTrackingCallback`` populates the accumulator automatically. Raises: - QuotaInsufficientError: when premium and ``credit_reserve`` denies. + QuotaInsufficientError: when premium and ``premium_reserve`` denies. """ is_premium = billing_tier == "premium" session_factory = billable_session_factory or shielded_async_session @@ -308,7 +310,7 @@ async def billable_call( request_id = str(uuid4()) async with session_factory() as quota_session: - reserve_result = await TokenQuotaService.credit_reserve( + reserve_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=user_id, request_id=request_id, @@ -318,16 +320,18 @@ async def billable_call( if not reserve_result.allowed: logger.info( "[billable_call] reserve DENIED user=%s usage_type=%s " - "reserve=%d balance=%d remaining=%d", + "reserve=%d used=%d limit=%d remaining=%d", user_id, usage_type, reserve_micros, - reserve_result.balance, + reserve_result.used, + reserve_result.limit, reserve_result.remaining, ) raise QuotaInsufficientError( usage_type=usage_type, - balance_micros=reserve_result.balance, + used_micros=reserve_result.used, + limit_micros=reserve_result.limit, remaining_micros=reserve_result.remaining, ) @@ -348,14 +352,14 @@ async def billable_call( # BaseException so cancellation also releases. try: async with session_factory() as quota_session: - await TokenQuotaService.credit_release( + await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, reserved_micros=reserve_micros, ) except Exception: logger.exception( - "[billable_call] credit_release failed for user=%s " + "[billable_call] premium_release failed for user=%s " "reserve_micros=%d (reservation will be GC'd by quota " "reconciliation if/when implemented)", user_id, @@ -376,7 +380,7 @@ async def billable_call( thread_id, ) async with session_factory() as quota_session: - final_result = await TokenQuotaService.credit_finalize( + final_result = await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=user_id, request_id=request_id, @@ -385,25 +389,26 @@ async def billable_call( ) logger.info( "[billable_call] finalize user=%s usage_type=%s actual=%d " - "reserved=%d → balance=%d (remaining=%d)", + "reserved=%d → used=%d/%d (remaining=%d)", user_id, usage_type, actual_micros, reserve_micros, - final_result.balance, + final_result.used, + final_result.limit, final_result.remaining, ) except Exception as finalize_exc: # Last-ditch: if finalize itself fails, we must at least release # so the reservation doesn't leak. logger.exception( - "[billable_call] credit_finalize failed for user=%s; " + "[billable_call] premium_finalize failed for user=%s; " "attempting release", user_id, ) try: async with session_factory() as quota_session: - await TokenQuotaService.credit_release( + await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, reserved_micros=reserve_micros, @@ -460,7 +465,7 @@ async def _resolve_agent_billing_for_search_space( so the same model bills for chat + downstream podcast/video. If the user is not premium-eligible, the pin service auto-restricts to free deployments — denial only happens later in - ``billable_call.credit_reserve`` if the pin really is premium and + ``billable_call.premium_reserve`` if the pin really is premium and credit ran out mid-flow. * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat for any future direct-API path; today both Celery tasks always pass diff --git a/surfsense_backend/app/services/etl_credit_service.py b/surfsense_backend/app/services/page_limit_service.py similarity index 67% rename from surfsense_backend/app/services/etl_credit_service.py rename to surfsense_backend/app/services/page_limit_service.py index 5c4ea4bbd..47fe07fc6 100644 --- a/surfsense_backend/app/services/etl_credit_service.py +++ b/surfsense_backend/app/services/page_limit_service.py @@ -1,14 +1,5 @@ """ -Service for charging the unified credit wallet for ETL document processing. - -Replaces the legacy ``PageLimitService`` page-quota model. Page counts are -still estimated the same way; they are now converted to USD micro-credits -(``config.MICROS_PER_PAGE`` per page, times a per-mode multiplier) and debited -from ``user.credit_micros_balance``. - -When ``config.ETL_CREDIT_BILLING_ENABLED`` is False (the default for -self-hosted / OSS installs) every check/charge is a no-op, preserving the prior -effectively-unlimited ETL behaviour. +Service for managing user page limits for ETL services. """ import os @@ -17,125 +8,141 @@ from pathlib import Path, PurePosixPath from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.config import config - -class InsufficientCreditsError(Exception): - """Raised when a user lacks enough credit to process a document.""" +class PageLimitExceededError(Exception): + """ + Exception raised when a user exceeds their page processing limit. + """ def __init__( self, - message: str = "Insufficient credits to process this document. " - "Add more credits to continue.", - balance_micros: int = 0, - required_micros: int = 0, + message: str = "Page limit exceeded. Please contact admin to increase limits for your account.", + pages_used: int = 0, + pages_limit: int = 0, + pages_to_add: int = 0, ): - self.balance_micros = balance_micros - self.required_micros = required_micros + self.pages_used = pages_used + self.pages_limit = pages_limit + self.pages_to_add = pages_to_add super().__init__(message) -class EtlCreditService: - """Checks and charges the credit wallet for ETL page processing.""" +class PageLimitService: + """Service for checking and updating user page limits.""" def __init__(self, session: AsyncSession): self.session = session - @staticmethod - def billing_enabled() -> bool: - return config.ETL_CREDIT_BILLING_ENABLED - - @staticmethod - def pages_to_micros(pages: int, multiplier: int = 1) -> int: - """Convert a (multiplied) page count to USD micro-credits.""" - return int(pages) * int(multiplier) * config.MICROS_PER_PAGE - - async def get_available_micros(self, user_id: str) -> int | None: - """Return spendable credit in micro-USD (``balance - reserved``). - - Returns ``None`` when ETL billing is disabled, which callers treat as - "unlimited" (no batch skipping, no blocking). + async def check_page_limit( + self, user_id: str, estimated_pages: int = 1 + ) -> tuple[bool, int, int]: """ - if not config.ETL_CREDIT_BILLING_ENABLED: - return None + Check if user has enough pages remaining for processing. + Args: + user_id: The user's ID + estimated_pages: Estimated number of pages to be processed + + Returns: + Tuple of (has_capacity, pages_used, pages_limit) + + Raises: + PageLimitExceededError: If user would exceed their page limit + """ from app.db import User + # Get user's current page usage result = await self.session.execute( - select(User.credit_micros_balance, User.credit_micros_reserved).where( - User.id == user_id - ) + select(User.pages_used, User.pages_limit).where(User.id == user_id) ) row = result.first() + if not row: raise ValueError(f"User with ID {user_id} not found") - balance, reserved = row - return balance - reserved + pages_used, pages_limit = row - async def check_credits( - self, user_id: str, estimated_pages: int = 1, multiplier: int = 1 - ) -> None: - """Raise :class:`InsufficientCreditsError` if the user can't afford to - process ``estimated_pages`` (times ``multiplier``). - - No-op when ETL billing is disabled. - """ - if not config.ETL_CREDIT_BILLING_ENABLED: - return - - required = self.pages_to_micros(estimated_pages, multiplier) - available = await self.get_available_micros(user_id) - if available is None: - return - - if required > available: - raise InsufficientCreditsError( - message=( - "Processing this document would exceed your available " - f"credit. Available: ${available / 1_000_000:.2f}. " - f"This document costs about ${required / 1_000_000:.2f} " - f"({estimated_pages} page(s)). Add more credits to continue." - ), - balance_micros=available, - required_micros=required, + # Check if adding estimated pages would exceed limit + if pages_used + estimated_pages > pages_limit: + raise PageLimitExceededError( + message=f"Processing this document would exceed your page limit. " + f"Used: {pages_used}/{pages_limit} pages. " + f"Document has approximately {estimated_pages} page(s). " + f"Please contact admin to increase limits for your account.", + pages_used=pages_used, + pages_limit=pages_limit, + pages_to_add=estimated_pages, ) - async def charge_credits( - self, user_id: str, pages: int, multiplier: int = 1 - ) -> int | None: - """Debit the credit wallet after successful processing. + return True, pages_used, pages_limit - The balance may dip slightly negative when the actual page count - exceeds the pre-check estimate (the document is already processed), - mirroring the prior ``allow_exceed=True`` semantics. - - Returns the new balance in micros, or ``None`` when billing is disabled. + async def update_page_usage( + self, user_id: str, pages_to_add: int, allow_exceed: bool = False + ) -> int: """ - if not config.ETL_CREDIT_BILLING_ENABLED: - return None + Update user's page usage after successful processing. + Args: + user_id: The user's ID + pages_to_add: Number of pages to add to usage + allow_exceed: If True, allows update even if it exceeds limit + (used when document was already processed after passing initial check) + + Returns: + New total pages_used value + + Raises: + PageLimitExceededError: If adding pages would exceed limit and allow_exceed is False + """ from app.db import User + # Get user result = await self.session.execute(select(User).where(User.id == user_id)) user = result.unique().scalar_one_or_none() + if not user: raise ValueError(f"User with ID {user_id} not found") - cost = self.pages_to_micros(pages, multiplier) - user.credit_micros_balance -= cost + # Check if this would exceed limit (only if allow_exceed is False) + new_usage = user.pages_used + pages_to_add + if not allow_exceed and new_usage > user.pages_limit: + raise PageLimitExceededError( + message=f"Cannot update page usage. Would exceed limit. " + f"Current: {user.pages_used}/{user.pages_limit}, " + f"Trying to add: {pages_to_add}", + pages_used=user.pages_used, + pages_limit=user.pages_limit, + pages_to_add=pages_to_add, + ) + + # Update usage + user.pages_used = new_usage await self.session.commit() await self.session.refresh(user) - # Best-effort: fire an auto-reload check if the balance dropped low. - try: - from app.services.auto_reload_service import maybe_trigger_auto_reload + return user.pages_used - await maybe_trigger_auto_reload(user_id) - except Exception: - pass + async def get_page_usage(self, user_id: str) -> tuple[int, int]: + """ + Get user's current page usage and limit. - return user.credit_micros_balance + Args: + user_id: The user's ID + + Returns: + Tuple of (pages_used, pages_limit) + """ + from app.db import User + + result = await self.session.execute( + select(User.pages_used, User.pages_limit).where(User.id == user_id) + ) + row = result.first() + + if not row: + raise ValueError(f"User with ID {user_id} not found") + + return row def estimate_pages_from_elements(self, elements: list) -> int: """ diff --git a/surfsense_backend/app/services/public_chat_service.py b/surfsense_backend/app/services/public_chat_service.py index d17f411b8..e4e0dd33a 100644 --- a/surfsense_backend/app/services/public_chat_service.py +++ b/surfsense_backend/app/services/public_chat_service.py @@ -337,9 +337,6 @@ async def _get_podcast_for_snapshot( "original_id": podcast.id, "title": podcast.title, "transcript": podcast.podcast_transcript, - "storage_backend": podcast.storage_backend, - "storage_key": podcast.storage_key, - # Legacy fallback for rows rendered before the storage migration. "file_path": podcast.file_location, } @@ -720,8 +717,6 @@ async def clone_from_snapshot( new_podcast = Podcast( title=podcast_info.get("title", "Cloned Podcast"), podcast_transcript=podcast_info.get("transcript"), - storage_backend=podcast_info.get("storage_backend"), - storage_key=podcast_info.get("storage_key"), file_location=podcast_info.get("file_path"), status=PodcastStatus.READY, search_space_id=target_search_space_id, diff --git a/surfsense_backend/app/services/token_quota_service.py b/surfsense_backend/app/services/token_quota_service.py index d32c18722..310c3eb5e 100644 --- a/surfsense_backend/app/services/token_quota_service.py +++ b/surfsense_backend/app/services/token_quota_service.py @@ -99,18 +99,7 @@ class QuotaStatus(StrEnum): class QuotaResult: - # ``used``/``limit`` are used by the anonymous (Redis) token path. - # ``balance``/``remaining``/``reserved`` are used by the credit (Postgres) - # path, all in USD micro-units. ``remaining`` == spendable (balance - reserved). - __slots__ = ( - "allowed", - "balance", - "limit", - "remaining", - "reserved", - "status", - "used", - ) + __slots__ = ("allowed", "limit", "remaining", "reserved", "status", "used") def __init__( self, @@ -120,7 +109,6 @@ class QuotaResult: limit: int, reserved: int = 0, remaining: int = 0, - balance: int = 0, ): self.allowed = allowed self.status = status @@ -128,7 +116,6 @@ class QuotaResult: self.limit = limit self.reserved = reserved self.remaining = remaining - self.balance = balance def to_dict(self) -> dict[str, Any]: return { @@ -138,7 +125,6 @@ class QuotaResult: "limit": self.limit, "reserved": self.reserved, "remaining": self.remaining, - "balance": self.balance, } @@ -519,33 +505,19 @@ class TokenQuotaService: # ------------------------------------------------------------------ @staticmethod - def _credit_status(balance: int) -> QuotaStatus: - """Map a spendable balance to OK / WARNING / BLOCKED. - - There is no longer a fixed ceiling, so WARNING fires on a low absolute - balance (``config.CREDIT_LOW_BALANCE_WARNING_MICROS``) instead of a - percentage of a limit. - """ - if balance <= 0: - return QuotaStatus.BLOCKED - if balance < config.CREDIT_LOW_BALANCE_WARNING_MICROS: - return QuotaStatus.WARNING - return QuotaStatus.OK - - @staticmethod - async def credit_reserve( + async def premium_reserve( db_session: AsyncSession, user_id: Any, request_id: str, reserve_micros: int, ) -> QuotaResult: - """Reserve ``reserve_micros`` (USD micro-units) from the user's credit - wallet. + """Reserve ``reserve_micros`` (USD micro-units) from the user's + premium credit balance. - ``QuotaResult.balance``/``reserved``/``remaining`` are in micro-USD on - this code path; callers (chat stream, credit-status route, FE display) - convert to dollars by dividing by 1_000_000. ``remaining`` is the - spendable amount (``balance - reserved``). + ``QuotaResult.used``/``limit``/``reserved``/``remaining`` are + all in micro-USD on this code path; callers (chat stream, + token-status route, FE display) convert to dollars by dividing + by 1_000_000. """ from app.db import User @@ -566,41 +538,48 @@ class TokenQuotaService: limit=0, ) - balance = user.credit_micros_balance - reserved = user.credit_micros_reserved + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved - # Block when the new hold would exceed the spendable balance. - if reserved + reserve_micros > balance: - remaining = max(0, balance - reserved) + effective = used + reserved + reserve_micros + if effective > limit: + remaining = max(0, limit - used - reserved) await db_session.rollback() return QuotaResult( allowed=False, status=QuotaStatus.BLOCKED, - used=0, - limit=balance, + used=used, + limit=limit, reserved=reserved, remaining=remaining, - balance=balance, ) - user.credit_micros_reserved = reserved + reserve_micros + user.premium_credit_micros_reserved = reserved + reserve_micros await db_session.commit() new_reserved = reserved + reserve_micros - remaining = max(0, balance - new_reserved) + remaining = max(0, limit - used - new_reserved) + warning_threshold = int(limit * 0.8) + + if (used + new_reserved) >= limit: + status = QuotaStatus.BLOCKED + elif (used + new_reserved) >= warning_threshold: + status = QuotaStatus.WARNING + else: + status = QuotaStatus.OK return QuotaResult( allowed=True, - status=TokenQuotaService._credit_status(remaining), - used=0, - limit=balance, + status=status, + used=used, + limit=limit, reserved=new_reserved, remaining=remaining, - balance=balance, ) @staticmethod - async def credit_finalize( + async def premium_finalize( db_session: AsyncSession, user_id: Any, request_id: str, @@ -608,8 +587,7 @@ class TokenQuotaService: reserved_micros: int, ) -> QuotaResult: """Settle the reservation: release ``reserved_micros`` and debit - ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD) - from the balance. + ``actual_micros`` (the LiteLLM-reported provider cost in micro-USD). """ from app.db import User @@ -627,42 +605,44 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - user.credit_micros_reserved = max( - 0, user.credit_micros_reserved - reserved_micros + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros + ) + user.premium_credit_micros_used = ( + user.premium_credit_micros_used + actual_micros ) - user.credit_micros_balance = user.credit_micros_balance - actual_micros await db_session.commit() - balance = user.credit_micros_balance - reserved = user.credit_micros_reserved - remaining = max(0, balance - reserved) + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved + remaining = max(0, limit - used - reserved) - # Best-effort auto-reload nudge after the debit settles. - try: - from app.services.auto_reload_service import maybe_trigger_auto_reload - - await maybe_trigger_auto_reload(user_id) - except Exception: - pass + warning_threshold = int(limit * 0.8) + if used >= limit: + status = QuotaStatus.BLOCKED + elif used >= warning_threshold: + status = QuotaStatus.WARNING + else: + status = QuotaStatus.OK return QuotaResult( allowed=True, - status=TokenQuotaService._credit_status(remaining), - used=0, - limit=balance, + status=status, + used=used, + limit=limit, reserved=reserved, remaining=remaining, - balance=balance, ) @staticmethod - async def credit_release( + async def premium_release( db_session: AsyncSession, user_id: Any, reserved_micros: int, ) -> None: - """Release ``reserved_micros`` previously held by ``credit_reserve``. + """Release ``reserved_micros`` previously held by ``premium_reserve``. Used when a request fails before finalize (so the reservation doesn't leak credit). @@ -679,13 +659,13 @@ class TokenQuotaService: .scalar_one_or_none() ) if user is not None: - user.credit_micros_reserved = max( - 0, user.credit_micros_reserved - reserved_micros + user.premium_credit_micros_reserved = max( + 0, user.premium_credit_micros_reserved - reserved_micros ) await db_session.commit() @staticmethod - async def credit_get_usage( + async def premium_get_usage( db_session: AsyncSession, user_id: Any, ) -> QuotaResult: @@ -701,16 +681,24 @@ class TokenQuotaService: allowed=False, status=QuotaStatus.BLOCKED, used=0, limit=0 ) - balance = user.credit_micros_balance - reserved = user.credit_micros_reserved - remaining = max(0, balance - reserved) + limit = user.premium_credit_micros_limit + used = user.premium_credit_micros_used + reserved = user.premium_credit_micros_reserved + remaining = max(0, limit - used - reserved) + + warning_threshold = int(limit * 0.8) + if used >= limit: + status = QuotaStatus.BLOCKED + elif used >= warning_threshold: + status = QuotaStatus.WARNING + else: + status = QuotaStatus.OK return QuotaResult( - allowed=remaining > 0, - status=TokenQuotaService._credit_status(remaining), - used=0, - limit=balance, + allowed=used < limit, + status=status, + used=used, + limit=limit, reserved=reserved, remaining=remaining, - balance=balance, ) diff --git a/surfsense_backend/app/tasks/celery_tasks/auto_reload_task.py b/surfsense_backend/app/tasks/celery_tasks/auto_reload_task.py deleted file mode 100644 index 385cdde88..000000000 --- a/surfsense_backend/app/tasks/celery_tasks/auto_reload_task.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Debit-triggered off-session credit auto-reload. - -Enqueued (best-effort) by ``auto_reload_service.maybe_trigger_auto_reload`` -after a credit debit drops the wallet below the user's threshold. This task is -the authoritative path: it re-checks eligibility under a row lock, enforces the -cooldown, then charges the saved card off-session via a Stripe PaymentIntent -(Stripe: charging a saved card off-session). - -Idempotency comes from three layers: -- a per-attempt CreditPurchase row created PENDING before the charge, -- a Stripe idempotency key derived from that row id, -- the ``payment_intent.*`` webhook backstop in ``stripe_routes`` that only - transitions PENDING rows. -""" - -from __future__ import annotations - -import logging -import uuid -from datetime import UTC, datetime, timedelta - -from sqlalchemy import select -from stripe import CardError, StripeClient, StripeError - -from app.celery_app import celery_app -from app.config import config -from app.db import CreditPurchase, CreditPurchaseStatus, User -from app.notifications.service import NotificationService -from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task - -logger = logging.getLogger(__name__) - -# 1_000_000 micro-USD == $1.00 == 100 cents, so cents = micros / 10_000. -_MICROS_PER_CENT = 10_000 - - -def _get_stripe_client() -> StripeClient | None: - if not config.STRIPE_SECRET_KEY: - logger.warning("Auto-reload skipped because STRIPE_SECRET_KEY is not set.") - return None - return StripeClient(config.STRIPE_SECRET_KEY) - - -def _card_error_payment_intent_id(exc: CardError) -> str | None: - """Pull the PaymentIntent id off a declined off-session charge. - - Per Stripe's off-session guide the failed intent is on ``exc.error.payment_intent``, - which may be a StripeObject or a plain dict depending on the SDK path. - """ - err = getattr(exc, "error", None) - pi = getattr(err, "payment_intent", None) if err is not None else None - if pi is None: - return None - if isinstance(pi, dict): - return pi.get("id") - return getattr(pi, "id", None) - - -@celery_app.task(name="auto_reload_credits") -def auto_reload_credits_task(user_id: str): - """Charge the user's saved card to top up credits when below threshold.""" - return run_async_celery_task(lambda: _auto_reload_credits(user_id)) - - -async def _auto_reload_credits(user_id: str) -> None: - if not config.AUTO_RELOAD_ENABLED: - return - - stripe_client = _get_stripe_client() - if stripe_client is None: - return - - cooldown = timedelta(minutes=max(config.AUTO_RELOAD_COOLDOWN_MINUTES, 0)) - now = datetime.now(UTC) - cutoff = now - cooldown - - async with get_celery_session_maker()() as db_session: - # Lock the user row so concurrent debits/tasks can't double-charge. - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == uuid.UUID(user_id)) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if user is None or not user.auto_reload_enabled: - return - - if not (user.stripe_customer_id and user.auto_reload_payment_method_id): - return - - threshold = user.auto_reload_threshold_micros - amount = user.auto_reload_amount_micros - if not threshold or not amount: - return - - available = user.credit_micros_balance - user.credit_micros_reserved - if available >= threshold: - # Another reload (or a refund/grant) already restored the balance. - return - - # Cooldown: skip if a recent auto-reload purchase or failure happened. - recent = ( - await db_session.execute( - select(CreditPurchase.id) - .where( - CreditPurchase.user_id == user.id, - CreditPurchase.source == "auto_reload", - CreditPurchase.created_at >= cutoff, - CreditPurchase.status.in_( - [ - CreditPurchaseStatus.PENDING, - CreditPurchaseStatus.COMPLETED, - ] - ), - ) - .limit(1) - ) - ).first() - if recent is not None: - return - if user.auto_reload_failed_at and user.auto_reload_failed_at >= cutoff: - return - - customer_id = user.stripe_customer_id - payment_method_id = user.auto_reload_payment_method_id - amount_cents = max(round(amount / _MICROS_PER_CENT), 1) - - # Create the PENDING purchase row first so its id seeds the Stripe - # idempotency key and the webhook backstop can find it. - purchase = CreditPurchase( - user_id=user.id, - stripe_checkout_session_id=f"auto_reload:{uuid.uuid4()}", - quantity=0, - credit_micros_granted=amount, - amount_total=amount_cents, - currency="usd", - source="auto_reload", - status=CreditPurchaseStatus.PENDING, - ) - db_session.add(purchase) - await db_session.flush() - purchase_id = purchase.id - await db_session.commit() - - # Charge off-session outside the user-row lock so the network call doesn't - # hold the row. The purchase row is the synchronization point now. - try: - payment_intent = stripe_client.v1.payment_intents.create( - params={ - "amount": amount_cents, - "currency": "usd", - "customer": customer_id, - "payment_method": payment_method_id, - "off_session": True, - "confirm": True, - "metadata": { - "user_id": str(user_id), - "purchase_type": "auto_reload", - "purchase_id": str(purchase_id), - }, - }, - options={"idempotency_key": f"auto_reload:{purchase_id}"}, - ) - except CardError as exc: - await _record_failure( - purchase_id, - user_id, - amount, - payment_intent_id=_card_error_payment_intent_id(exc), - reason=getattr(exc, "user_message", None) or "Your card was declined.", - ) - return - except StripeError: - logger.exception("Auto-reload charge failed for user %s", user_id) - await _record_failure( - purchase_id, - user_id, - amount, - payment_intent_id=None, - reason="We couldn't process the charge. Please try again.", - ) - return - - payment_intent_id = str(payment_intent.id) - pi_status = getattr(payment_intent, "status", None) - - async with get_celery_session_maker()() as db_session: - purchase = ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.id == purchase_id) - .with_for_update() - ) - ).scalar_one_or_none() - if purchase is None: - return - purchase.stripe_payment_intent_id = payment_intent_id - - if pi_status == "succeeded": - if purchase.status != CreditPurchaseStatus.COMPLETED: - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == purchase.user_id) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one() - ) - purchase.status = CreditPurchaseStatus.COMPLETED - purchase.completed_at = datetime.now(UTC) - user.credit_micros_balance = ( - user.credit_micros_balance + purchase.credit_micros_granted - ) - user.auto_reload_failed_at = None - await db_session.commit() - logger.info( - "Auto-reload succeeded for user %s (+%s micro-USD)", - user_id, - amount, - ) - return - - # Not succeeded synchronously (e.g. requires_action / processing). - # Leave the row PENDING; the payment_intent webhook reconciles it. - await db_session.commit() - logger.info( - "Auto-reload PaymentIntent %s for user %s is %s; awaiting webhook.", - payment_intent_id, - user_id, - pi_status, - ) - - -async def _record_failure( - purchase_id: uuid.UUID, - user_id: str, - amount_micros: int, - *, - payment_intent_id: str | None, - reason: str | None, -) -> None: - """Mark the purchase FAILED, stamp the user, and notify them.""" - async with get_celery_session_maker()() as db_session: - purchase = ( - await db_session.execute( - select(CreditPurchase) - .where(CreditPurchase.id == purchase_id) - .with_for_update() - ) - ).scalar_one_or_none() - if purchase is not None and purchase.status == CreditPurchaseStatus.PENDING: - purchase.status = CreditPurchaseStatus.FAILED - if payment_intent_id: - purchase.stripe_payment_intent_id = payment_intent_id - - user = ( - ( - await db_session.execute( - select(User) - .where(User.id == uuid.UUID(user_id)) - .with_for_update(of=User) - ) - ) - .unique() - .scalar_one_or_none() - ) - if user is not None: - user.auto_reload_failed_at = datetime.now(UTC) - # Disable so a declined card doesn't get retried every debit; the - # user re-enables from settings (which clears the failure flag). - user.auto_reload_enabled = False - - await db_session.commit() - - try: - await NotificationService.auto_reload_failed.notify_auto_reload_failed( - session=db_session, - user_id=uuid.UUID(user_id), - amount_micros=amount_micros, - payment_intent_id=payment_intent_id, - reason=reason, - ) - except Exception: - logger.warning( - "Failed to create auto_reload_failed notification for user %s", - user_id, - exc_info=True, - ) diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 41e029a60..d38014124 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -668,52 +668,52 @@ async def _process_file_upload( # Import here to avoid circular dependencies from fastapi import HTTPException - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - # Check if this is an insufficient-credit error (either direct or - # wrapped in HTTPException) - credit_error: InsufficientCreditsError | None = None - if isinstance(e, InsufficientCreditsError): - credit_error = e + # Check if this is a page limit error (either direct or wrapped in HTTPException) + page_limit_error: PageLimitExceededError | None = None + if isinstance(e, PageLimitExceededError): + page_limit_error = e elif ( isinstance(e, HTTPException) and e.__cause__ - and isinstance(e.__cause__, InsufficientCreditsError) + and isinstance(e.__cause__, PageLimitExceededError) ): - # HTTPException wraps the original InsufficientCreditsError - credit_error = e.__cause__ - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): - # Fallback: HTTPException with credit message but no cause - credit_error = None # We don't have the details + # HTTPException wraps the original PageLimitExceededError + page_limit_error = e.__cause__ + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): + # Fallback: HTTPException with page limit message but no cause + page_limit_error = None # We don't have the details - # For insufficient-credit errors, create a dedicated notification - if credit_error is not None: - error_message = str(credit_error) - # Create a dedicated insufficient credits notification + # For page limit errors, create a dedicated page_limit_exceeded notification + if page_limit_error is not None: + error_message = str(page_limit_error) + # Create a dedicated page limit exceeded notification try: # First, mark the processing notification as failed await session.refresh(notification) await NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, - error_message="Insufficient credits", + error_message="Page limit exceeded", ) - # Then create a separate insufficient_credits notification for better UX - await NotificationService.insufficient_credits.notify_insufficient_credits( + # Then create a separate page_limit_exceeded notification for better UX + await NotificationService.page_limit.notify_page_limit_exceeded( session=session, user_id=UUID(user_id), document_name=filename, document_type="FILE", search_space_id=search_space_id, - balance_micros=credit_error.balance_micros, - required_micros=credit_error.required_micros, + pages_used=page_limit_error.pages_used, + pages_limit=page_limit_error.pages_limit, + pages_to_add=page_limit_error.pages_to_add, ) except Exception as notif_error: logger.error( - f"Failed to create insufficient credits notification: {notif_error!s}" + f"Failed to create page limit notification: {notif_error!s}" ) - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): # HTTPException with page limit message but no detailed cause error_message = str(e.detail) try: @@ -984,18 +984,18 @@ async def _process_file_with_document( # Import here to avoid circular dependencies from fastapi import HTTPException - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - # Check if this is an insufficient-credit error - credit_error: InsufficientCreditsError | None = None - if isinstance(e, InsufficientCreditsError): - credit_error = e + # Check if this is a page limit error + page_limit_error: PageLimitExceededError | None = None + if isinstance(e, PageLimitExceededError): + page_limit_error = e elif ( isinstance(e, HTTPException) and e.__cause__ - and isinstance(e.__cause__, InsufficientCreditsError) + and isinstance(e.__cause__, PageLimitExceededError) ): - credit_error = e.__cause__ + page_limit_error = e.__cause__ # Mark document as failed (shows error in UI via Zero) error_message = str(e)[:500] @@ -1006,27 +1006,28 @@ async def _process_file_with_document( f"[_process_file_with_document] Document {document_id} marked as failed: {error_message[:100]}" ) - # Handle insufficient-credit errors with dedicated notification - if credit_error is not None: + # Handle page limit errors with dedicated notification + if page_limit_error is not None: try: await session.refresh(notification) await NotificationService.document_processing.notify_processing_completed( session=session, notification=notification, - error_message="Insufficient credits", + error_message="Page limit exceeded", ) - await NotificationService.insufficient_credits.notify_insufficient_credits( + await NotificationService.page_limit.notify_page_limit_exceeded( session=session, user_id=UUID(user_id), document_name=filename, document_type="FILE", search_space_id=search_space_id, - balance_micros=credit_error.balance_micros, - required_micros=credit_error.required_micros, + pages_used=page_limit_error.pages_used, + pages_limit=page_limit_error.pages_limit, + pages_to_add=page_limit_error.pages_to_add, ) except Exception as notif_error: logger.error( - f"Failed to create insufficient credits notification: {notif_error!s}" + f"Failed to create page limit notification: {notif_error!s}" ) else: # Update notification on failure diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py new file mode 100644 index 000000000..8b311576e --- /dev/null +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -0,0 +1,236 @@ +"""Celery tasks for podcast generation.""" + +import asyncio +import logging +import sys +from contextlib import asynccontextmanager + +from sqlalchemy import select + +from app.agents.podcaster.graph import graph as podcaster_graph +from app.agents.podcaster.state import State as PodcasterState +from app.celery_app import celery_app +from app.config import config as app_config +from app.db import Podcast, PodcastStatus +from app.services.billable_calls import ( + BillingSettlementError, + QuotaInsufficientError, + _resolve_agent_billing_for_search_space, + billable_call, +) +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task + +logger = logging.getLogger(__name__) + +if sys.platform.startswith("win"): + try: + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + except AttributeError: + logger.warning( + "WindowsProactorEventLoopPolicy is unavailable; async subprocess support may fail." + ) + + +# ============================================================================= +# Content-based podcast generation (for new-chat) +# ============================================================================= + + +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + +@celery_app.task(name="generate_content_podcast", bind=True) +def generate_content_podcast_task( + self, + podcast_id: int, + source_content: str, + search_space_id: int, + user_prompt: str | None = None, +) -> dict: + """ + Celery task to generate podcast from source content. + Updates existing podcast record created by the tool. + """ + try: + return run_async_celery_task( + lambda: _generate_content_podcast( + podcast_id, + source_content, + search_space_id, + user_prompt, + ) + ) + except Exception as e: + logger.error(f"Error generating content podcast: {e!s}") + try: + run_async_celery_task(lambda: _mark_podcast_failed(podcast_id)) + except Exception: + logger.exception("Failed to mark podcast %s as failed", podcast_id) + return {"status": "failed", "podcast_id": podcast_id} + + +async def _mark_podcast_failed(podcast_id: int) -> None: + """Mark a podcast as failed in the database.""" + async with get_celery_session_maker()() as session: + try: + result = await session.execute( + select(Podcast).filter(Podcast.id == podcast_id) + ) + podcast = result.scalars().first() + if podcast: + podcast.status = PodcastStatus.FAILED + await session.commit() + except Exception as e: + logger.error(f"Failed to mark podcast as failed: {e}") + + +async def _generate_content_podcast( + podcast_id: int, + source_content: str, + search_space_id: int, + user_prompt: str | None = None, +) -> dict: + """Generate content-based podcast and update existing record.""" + async with get_celery_session_maker()() as session: + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + podcast = result.scalars().first() + + if not podcast: + raise ValueError(f"Podcast {podcast_id} not found") + + try: + podcast.status = PodcastStatus.GENERATING + await session.commit() + + try: + ( + owner_user_id, + billing_tier, + base_model, + ) = await _resolve_agent_billing_for_search_space( + session, + search_space_id, + thread_id=podcast.thread_id, + ) + except ValueError as resolve_err: + logger.error( + "Podcast %s: cannot resolve billing for search_space=%s: %s", + podcast.id, + search_space_id, + resolve_err, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_resolution_failed", + } + + graph_config = { + "configurable": { + "podcast_title": podcast.title, + "search_space_id": search_space_id, + "user_prompt": user_prompt, + } + } + + initial_state = PodcasterState( + source_content=source_content, + db_session=session, + ) + + try: + async with billable_call( + user_id=owner_user_id, + search_space_id=search_space_id, + billing_tier=billing_tier, + base_model=base_model, + quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, + usage_type="podcast_generation", + call_details={ + "podcast_id": podcast.id, + "title": podcast.title, + "thread_id": podcast.thread_id, + }, + billable_session_factory=_celery_billable_session, + ): + graph_result = await podcaster_graph.ainvoke( + initial_state, config=graph_config + ) + except QuotaInsufficientError as exc: + logger.info( + "Podcast %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", + podcast.id, + exc.used_micros, + exc.limit_micros, + exc.remaining_micros, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "premium_quota_exhausted", + } + except BillingSettlementError: + logger.exception( + "Podcast %s: premium billing settlement failed", + podcast.id, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_settlement_failed", + } + + podcast_transcript = graph_result.get("podcast_transcript", []) + file_path = graph_result.get("final_podcast_file_path", "") + + serializable_transcript = [] + for entry in podcast_transcript: + if hasattr(entry, "speaker_id"): + serializable_transcript.append( + {"speaker_id": entry.speaker_id, "dialog": entry.dialog} + ) + else: + serializable_transcript.append( + { + "speaker_id": entry.get("speaker_id", 0), + "dialog": entry.get("dialog", ""), + } + ) + + podcast.podcast_transcript = serializable_transcript + podcast.file_location = file_path + podcast.status = PodcastStatus.READY + logger.info( + "Podcast %s: committing READY transcript_entries=%d file=%s", + podcast.id, + len(serializable_transcript), + file_path, + ) + await session.commit() + logger.info("Podcast %s: READY commit complete", podcast.id) + + logger.info(f"Successfully generated podcast: {podcast.id}") + + return { + "status": "ready", + "podcast_id": podcast.id, + "title": podcast.title, + "transcript_entries": len(serializable_transcript), + } + + except Exception as e: + logger.error(f"Error in _generate_content_podcast: {e!s}") + podcast.status = PodcastStatus.FAILED + await session.commit() + raise diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py index f1ed6c6b3..ace6ef7ca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py @@ -1,4 +1,4 @@ -"""Reconcile pending Stripe credit purchases that might miss webhook fulfillment.""" +"""Reconcile pending Stripe purchases that might miss webhook fulfillment.""" from __future__ import annotations @@ -11,8 +11,10 @@ from stripe import StripeClient, StripeError from app.celery_app import celery_app from app.config import config from app.db import ( - CreditPurchase, - CreditPurchaseStatus, + PagePurchase, + PagePurchaseStatus, + PremiumTokenPurchase, + PremiumTokenPurchaseStatus, ) from app.routes import stripe_routes from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task @@ -30,14 +32,14 @@ def get_stripe_client() -> StripeClient | None: return StripeClient(config.STRIPE_SECRET_KEY) -@celery_app.task(name="reconcile_pending_stripe_credit_purchases") -def reconcile_pending_stripe_credit_purchases_task(): - """Recover paid credit purchases that were left pending due to missed webhook handling.""" - return run_async_celery_task(_reconcile_pending_credit_purchases) +@celery_app.task(name="reconcile_pending_stripe_page_purchases") +def reconcile_pending_stripe_page_purchases_task(): + """Recover paid purchases that were left pending due to missed webhook handling.""" + return run_async_celery_task(_reconcile_pending_page_purchases) -async def _reconcile_pending_credit_purchases() -> None: - """Reconcile stale pending credit purchases against Stripe source of truth. +async def _reconcile_pending_page_purchases() -> None: + """Reconcile stale pending page purchases against Stripe source of truth. Stripe retries webhook delivery automatically, but best practice is to add an application-level reconciliation path in case all retries fail or the endpoint @@ -55,12 +57,12 @@ async def _reconcile_pending_credit_purchases() -> None: pending_purchases = ( ( await db_session.execute( - select(CreditPurchase) + select(PagePurchase) .where( - CreditPurchase.status == CreditPurchaseStatus.PENDING, - CreditPurchase.created_at <= cutoff, + PagePurchase.status == PagePurchaseStatus.PENDING, + PagePurchase.created_at <= cutoff, ) - .order_by(CreditPurchase.created_at.asc()) + .order_by(PagePurchase.created_at.asc()) .limit(batch_size) ) ) @@ -70,13 +72,13 @@ async def _reconcile_pending_credit_purchases() -> None: if not pending_purchases: logger.debug( - "Stripe credit reconciliation found no pending purchases older than %s minutes.", + "Stripe reconciliation found no pending purchases older than %s minutes.", lookback_minutes, ) return logger.info( - "Stripe credit reconciliation checking %s pending purchases (cutoff=%s, batch=%s).", + "Stripe reconciliation checking %s pending purchases (cutoff=%s, batch=%s).", len(pending_purchases), lookback_minutes, batch_size, @@ -94,7 +96,7 @@ async def _reconcile_pending_credit_purchases() -> None: ) except StripeError: logger.exception( - "Stripe credit reconciliation failed to retrieve checkout session %s", + "Stripe reconciliation failed to retrieve checkout session %s", checkout_session_id, ) await db_session.rollback() @@ -105,24 +107,119 @@ async def _reconcile_pending_credit_purchases() -> None: try: if payment_status in {"paid", "no_payment_required"}: - await stripe_routes._fulfill_completed_credit_purchase( + await stripe_routes._fulfill_completed_purchase( db_session, checkout_session ) fulfilled_count += 1 elif session_status == "expired": - await stripe_routes._mark_credit_purchase_failed( + await stripe_routes._mark_purchase_failed( db_session, str(checkout_session.id) ) failed_count += 1 except Exception: logger.exception( - "Stripe credit reconciliation failed while processing checkout session %s", + "Stripe reconciliation failed while processing checkout session %s", checkout_session_id, ) await db_session.rollback() logger.info( - "Stripe credit reconciliation completed. fulfilled=%s failed=%s checked=%s", + "Stripe page reconciliation completed. fulfilled=%s failed=%s checked=%s", + fulfilled_count, + failed_count, + len(pending_purchases), + ) + + +@celery_app.task(name="reconcile_pending_stripe_token_purchases") +def reconcile_pending_stripe_token_purchases_task(): + """Recover paid token purchases that were left pending due to missed webhook handling.""" + return run_async_celery_task(_reconcile_pending_token_purchases) + + +async def _reconcile_pending_token_purchases() -> None: + """Reconcile stale pending token purchases against Stripe source of truth.""" + stripe_client = get_stripe_client() + if stripe_client is None: + return + + lookback_minutes = max(config.STRIPE_RECONCILIATION_LOOKBACK_MINUTES, 0) + batch_size = max(config.STRIPE_RECONCILIATION_BATCH_SIZE, 1) + cutoff = datetime.now(UTC) - timedelta(minutes=lookback_minutes) + + async with get_celery_session_maker()() as db_session: + pending_purchases = ( + ( + await db_session.execute( + select(PremiumTokenPurchase) + .where( + PremiumTokenPurchase.status + == PremiumTokenPurchaseStatus.PENDING, + PremiumTokenPurchase.created_at <= cutoff, + ) + .order_by(PremiumTokenPurchase.created_at.asc()) + .limit(batch_size) + ) + ) + .scalars() + .all() + ) + + if not pending_purchases: + logger.debug( + "Stripe token reconciliation found no pending purchases older than %s minutes.", + lookback_minutes, + ) + return + + logger.info( + "Stripe token reconciliation checking %s pending purchases (cutoff=%s, batch=%s).", + len(pending_purchases), + lookback_minutes, + batch_size, + ) + + fulfilled_count = 0 + failed_count = 0 + + for purchase in pending_purchases: + checkout_session_id = purchase.stripe_checkout_session_id + + try: + checkout_session = stripe_client.v1.checkout.sessions.retrieve( + checkout_session_id + ) + except StripeError: + logger.exception( + "Stripe token reconciliation failed to retrieve checkout session %s", + checkout_session_id, + ) + await db_session.rollback() + continue + + payment_status = getattr(checkout_session, "payment_status", None) + session_status = getattr(checkout_session, "status", None) + + try: + if payment_status in {"paid", "no_payment_required"}: + await stripe_routes._fulfill_completed_token_purchase( + db_session, checkout_session + ) + fulfilled_count += 1 + elif session_status == "expired": + await stripe_routes._mark_token_purchase_failed( + db_session, str(checkout_session.id) + ) + failed_count += 1 + except Exception: + logger.exception( + "Stripe token reconciliation failed while processing checkout session %s", + checkout_session_id, + ) + await db_session.rollback() + + logger.info( + "Stripe token reconciliation completed. fulfilled=%s failed=%s checked=%s", fulfilled_count, failed_count, len(pending_purchases), diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index c6ce0b350..08f22140c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py @@ -174,10 +174,11 @@ async def _generate_video_presentation( ) except QuotaInsufficientError as exc: logger.info( - "VideoPresentation %s denied: out of credits " - "(balance=%d remaining=%d)", + "VideoPresentation %s denied: out of premium credits " + "(used=%d/%d remaining=%d)", video_pres.id, - exc.balance_micros, + exc.used_micros, + exc.limit_micros, exc.remaining_micros, ) video_pres.status = VideoPresentationStatus.FAILED diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py index 1e6097e53..e33dca376 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py @@ -85,11 +85,11 @@ from app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( setup_connector_and_firecrawl, ) from app.tasks.chat.streaming.flows.shared.premium_quota import ( - CreditReservation, - finalize_credit, - needs_credit_quota, - release_credit, - reserve_credit, + PremiumReservation, + finalize_premium, + needs_premium_quota, + release_premium, + reserve_premium, ) from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( can_recover_provider_rate_limit, @@ -182,7 +182,7 @@ async def stream_new_chat( accumulator = start_turn() - premium_reservation: CreditReservation | None = None + premium_reservation: PremiumReservation | None = None busy_error_raised = False emit_stream_error = partial( @@ -259,8 +259,8 @@ async def stream_new_chat( yield streaming_service.format_done() return - if needs_credit_quota(agent_config, user_id): - premium_reservation = await reserve_credit( + if needs_premium_quota(agent_config, user_id): + premium_reservation = await reserve_premium( agent_config=agent_config, user_id=user_id, # type: ignore[arg-type] ) @@ -336,7 +336,7 @@ async def stream_new_chat( else: yield emit_stream_error( message=( - "Buy more credits to continue with this model, or " + "Buy more tokens to continue with this model, or " "switch to a free model" ), error_kind="premium_quota_exhausted", @@ -762,7 +762,7 @@ async def stream_new_chat( # sub-agent calls during a premium turn still contribute to the bill # (they're $0 in practice anyway). if premium_reservation is not None and user_id: - await finalize_credit( + await finalize_premium( reservation=premium_reservation, user_id=user_id, accumulator=accumulator, @@ -812,7 +812,7 @@ async def stream_new_chat( end_turn(str(chat_id)) if premium_reservation is not None and user_id: - await release_credit(reservation=premium_reservation, user_id=user_id) + await release_premium(reservation=premium_reservation, user_id=user_id) await close_session_and_clear_ai_responding(session, chat_id) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py index e1552e79e..6d0924850 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py @@ -64,11 +64,11 @@ from app.tasks.chat.streaming.flows.shared.pre_stream_setup import ( setup_connector_and_firecrawl, ) from app.tasks.chat.streaming.flows.shared.premium_quota import ( - CreditReservation, - finalize_credit, - needs_credit_quota, - release_credit, - reserve_credit, + PremiumReservation, + finalize_premium, + needs_premium_quota, + release_premium, + reserve_premium, ) from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import ( can_recover_provider_rate_limit, @@ -144,7 +144,7 @@ async def stream_resume_chat( accumulator = start_turn() - premium_reservation: CreditReservation | None = None + premium_reservation: PremiumReservation | None = None busy_error_raised = False emit_stream_error = partial( @@ -212,8 +212,8 @@ async def stream_resume_chat( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) - if needs_credit_quota(agent_config, user_id): - premium_reservation = await reserve_credit( + if needs_premium_quota(agent_config, user_id): + premium_reservation = await reserve_premium( agent_config=agent_config, user_id=user_id, # type: ignore[arg-type] ) @@ -285,7 +285,7 @@ async def stream_resume_chat( else: yield emit_stream_error( message=( - "Buy more credits to continue with this model, or " + "Buy more tokens to continue with this model, or " "switch to a free model" ), error_kind="premium_quota_exhausted", @@ -544,7 +544,7 @@ async def stream_resume_chat( return if premium_reservation is not None and user_id: - await finalize_credit( + await finalize_premium( reservation=premium_reservation, user_id=user_id, accumulator=accumulator, @@ -584,7 +584,7 @@ async def stream_resume_chat( end_turn(str(chat_id)) if premium_reservation is not None and user_id: - await release_credit(reservation=premium_reservation, user_id=user_id) + await release_premium(reservation=premium_reservation, user_id=user_id) await close_session_and_clear_ai_responding(session, chat_id) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py index 232071394..6c08cb29f 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/premium_quota.py @@ -1,12 +1,13 @@ -"""Credit wallet (USD micro-units) reserve / finalize / release lifecycle. +"""Premium credit (USD micro-units) reserve / finalize / release lifecycle. -Both ``stream_new_chat`` and ``stream_resume_chat`` reserve credits up front (so -a single LLM call can't run away with the budget), then finalize the actual -provider cost reported by LiteLLM when the turn completes successfully, or -release the reservation on the cancellation / interrupted-without-finalize paths. +Both ``stream_new_chat`` and ``stream_resume_chat`` reserve premium credits up +front (so a single LLM call can't run away with the budget), then finalize the +actual provider cost reported by LiteLLM when the turn completes successfully, +or release the reservation on the cancellation / interrupted-without-finalize +paths. -State is held by the orchestrator as a simple ``CreditReservation`` so -reservation, fallback-on-denied, finalize, and release can all be reasoned +State is held by the orchestrator as a simple ``PremiumReservation`` tuple +so reservation, fallback-on-denied, finalize, and release can all be reasoned about from one place. """ @@ -26,8 +27,8 @@ if TYPE_CHECKING: @dataclass -class CreditReservation: - """Active credit reservation for one turn. +class PremiumReservation: + """Active premium-credit reservation for one turn. ``request_id`` is the per-reservation idempotency key (also passed to ``finalize``/``release`` so racing branches resolve to the same row). @@ -40,15 +41,15 @@ class CreditReservation: allowed: bool -def needs_credit_quota(agent_config: AgentConfig | None, user_id: str | None) -> bool: +def needs_premium_quota(agent_config: AgentConfig | None, user_id: str | None) -> bool: return bool(agent_config is not None and user_id and agent_config.is_premium) -async def reserve_credit( +async def reserve_premium( *, agent_config: AgentConfig, user_id: str, -) -> CreditReservation: +) -> PremiumReservation: """Reserve estimated micros up front; returns the reservation handle.""" from app.services.token_quota_service import ( TokenQuotaService, @@ -67,22 +68,22 @@ async def reserve_credit( quota_reserve_tokens=agent_config.quota_reserve_tokens, ) async with shielded_async_session() as quota_session: - quota_result = await TokenQuotaService.credit_reserve( + quota_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=UUID(user_id), request_id=request_id, reserve_micros=reserve_amount_micros, ) - return CreditReservation( + return PremiumReservation( request_id=request_id, reserved_micros=reserve_amount_micros, allowed=quota_result.allowed, ) -async def finalize_credit( +async def finalize_premium( *, - reservation: CreditReservation, + reservation: PremiumReservation, user_id: str, accumulator: TokenAccumulator, ) -> None: @@ -95,7 +96,7 @@ async def finalize_credit( from app.services.token_quota_service import TokenQuotaService async with shielded_async_session() as quota_session: - await TokenQuotaService.credit_finalize( + await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=reservation.request_id, @@ -104,15 +105,15 @@ async def finalize_credit( ) except Exception: logging.getLogger(__name__).warning( - "Failed to finalize credit quota for user %s", + "Failed to finalize premium quota for user %s", user_id, exc_info=True, ) -async def release_credit( +async def release_premium( *, - reservation: CreditReservation, + reservation: PremiumReservation, user_id: str, ) -> None: """Release the reservation on cancellation paths; never raises.""" @@ -120,12 +121,12 @@ async def release_credit( from app.services.token_quota_service import TokenQuotaService async with shielded_async_session() as quota_session: - await TokenQuotaService.credit_release( + await TokenQuotaService.premium_release( db_session=quota_session, user_id=UUID(user_id), reserved_micros=reservation.reserved_micros, ) except Exception: logging.getLogger(__name__).warning( - "Failed to release credit quota for user %s", user_id + "Failed to release premium quota for user %s", user_id ) diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py index b21357b50..f1a1e9c37 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/emission.py @@ -15,32 +15,22 @@ def iter_completion_emission_frames( out = ctx.tool_output payload = out if isinstance(out, dict) else {"result": out} yield ctx.emit_tool_output_card(payload) - status = out.get("status") if isinstance(out, dict) else None - title = out.get("title", "Podcast") if isinstance(out, dict) else "Podcast" - if status in ( - "awaiting_brief", - "awaiting_review", + if isinstance(out, dict) and out.get("status") in ( "pending", - "drafting", - "rendering", + "generating", + "processing", ): - # This line is persisted with the chat while the podcast keeps moving, - # so it must stay true after the lifecycle outgrows today's status. yield ctx.streaming_service.format_terminal_info( - f"Podcast created: {title}", + f"Podcast queued: {out.get('title', 'Podcast')}", "success", ) - elif status in ("ready", "success"): + elif isinstance(out, dict) and out.get("status") in ("ready", "success"): yield ctx.streaming_service.format_terminal_info( - f"Podcast generated successfully: {title}", + f"Podcast generated successfully: {out.get('title', 'Podcast')}", "success", ) - elif status in ("failed", "error"): - error_msg = ( - out.get("error", "Unknown error") - if isinstance(out, dict) - else "Unknown error" - ) + elif isinstance(out, dict) and out.get("status") in ("failed", "error"): + error_msg = out.get("error", "Unknown error") yield ctx.streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", diff --git a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py index fe8f9cfb7..5cf78ea72 100644 --- a/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py +++ b/surfsense_backend/app/tasks/chat/streaming/handlers/tools/deliverables/generate_podcast/thinking.py @@ -24,11 +24,11 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking d.get("source_content", "") if isinstance(tool_input, dict) else "" ) return ToolStartThinking( - title="Preparing podcast", + title="Generating podcast", items=[ f"Title: {podcast_title}", f"Content: {content_len:,} characters", - "Proposing brief (language, voices, length)...", + "Preparing audio generation...", ], ) @@ -50,19 +50,17 @@ def resolve_completed_thinking( if isinstance(tool_output, dict) else "Podcast" ) - if podcast_status in ( - "awaiting_brief", - "awaiting_review", - "pending", - "drafting", - "rendering", - ): - # Persisted with the chat while the podcast keeps moving, so the copy - # must stay true after the lifecycle outgrows today's status. + if podcast_status in ("pending", "generating", "processing"): completed = [ f"Title: {podcast_title}", - "Podcast created", - "Review and progress continue on the podcast card", + "Podcast generation started", + "Processing in background...", + ] + elif podcast_status == "already_generating": + completed = [ + f"Title: {podcast_title}", + "Podcast already in progress", + "Please wait for it to complete", ] elif podcast_status in ("failed", "error"): error_msg = ( @@ -81,4 +79,4 @@ def resolve_completed_thinking( ] else: completed = items - return ("Preparing podcast", completed) + return ("Generating podcast", completed) diff --git a/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py b/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py index 9bf290d85..7cd3e1613 100644 --- a/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py @@ -28,7 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.exceptions import safe_exception_message from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -423,8 +423,9 @@ async def _index_full_scan( }, ) - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -466,17 +467,13 @@ async def _index_full_scan( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during Dropbox full scan, " + "Page limit reached during Dropbox full scan, " "skipping remaining files" ) page_limit_reached = True @@ -501,7 +498,9 @@ async def _index_full_scan( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( @@ -524,8 +523,9 @@ async def _index_selected_files( vision_llm=None, ) -> tuple[int, int, int, list[str]]: """Index user-selected files using the parallel pipeline.""" - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 files_to_download: list[dict] = [] @@ -560,16 +560,12 @@ async def _index_selected_files( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: display = file_name or file_path - errors.append(f"File '{display}': insufficient credits") + errors.append(f"File '{display}': page limit would be exceeded") continue batch_estimated_pages += file_pages @@ -590,7 +586,9 @@ async def _index_selected_files( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) return renamed_count + batch_indexed, skipped, unsupported_count, errors diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 37de66ffd..b76f84bac 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -41,7 +41,7 @@ from app.indexing_pipeline.indexing_pipeline_service import ( PlaceholderInfo, ) from app.services.composio_service import ComposioService -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -555,11 +555,11 @@ async def _process_single_file( return 1, 0, 0 return 0, 1, 0 - etl_credit_service = EtlCreditService(session) - estimated_pages = EtlCreditService.estimate_pages_from_metadata( + page_limit_service = PageLimitService(session) + estimated_pages = PageLimitService.estimate_pages_from_metadata( file_name, file.get("size") ) - await etl_credit_service.check_credits(user_id, estimated_pages) + await page_limit_service.check_page_limit(user_id, estimated_pages) markdown, drive_metadata, error = await download_and_extract_content( drive_client, file, vision_llm=vision_llm @@ -602,7 +602,9 @@ async def _process_single_file( continue await pipeline.index(document, connector_doc) - await etl_credit_service.charge_credits(user_id, estimated_pages) + await page_limit_service.update_page_usage( + user_id, estimated_pages, allow_exceed=True + ) logger.info(f"Successfully indexed Google Drive file: {file_name}") return 1, 0, 0 @@ -711,8 +713,9 @@ async def _index_selected_files( Returns (indexed_count, skipped_count, unsupported_count, errors). """ - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 files_to_download: list[dict] = [] @@ -738,16 +741,12 @@ async def _index_selected_files( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: display = file_name or file_id - errors.append(f"File '{display}': insufficient credits") + errors.append(f"File '{display}': page limit would be exceeded") continue batch_estimated_pages += file_pages @@ -776,7 +775,9 @@ async def _index_selected_files( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) return renamed_count + batch_indexed, skipped, unsupported_count, errors @@ -819,8 +820,9 @@ async def _index_full_scan( # ------------------------------------------------------------------ # Phase 1 (serial): collect files, run skip checks, track renames # ------------------------------------------------------------------ - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -875,19 +877,13 @@ async def _index_full_scan( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros( - batch_estimated_pages + file_pages - ) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during Google Drive full scan, " + "Page limit reached during Google Drive full scan, " "skipping remaining files" ) page_limit_reached = True @@ -942,7 +938,9 @@ async def _index_full_scan( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( @@ -998,8 +996,9 @@ async def _index_with_delta_sync( # ------------------------------------------------------------------ # Phase 1 (serial): handle removals, collect files for download # ------------------------------------------------------------------ - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -1035,17 +1034,13 @@ async def _index_with_delta_sync( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during Google Drive delta sync, " + "Page limit reached during Google Drive delta sync, " "skipping remaining files" ) page_limit_reached = True @@ -1084,7 +1079,9 @@ async def _index_with_delta_sync( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( diff --git a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py index 1a2d4b967..1cd92dcf8 100644 --- a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py @@ -33,7 +33,7 @@ from app.db import ( from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService -from app.services.etl_credit_service import EtlCreditService, InsufficientCreditsError +from app.services.page_limit_service import PageLimitExceededError, PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.celery_tasks import get_celery_session_maker from app.utils.document_versioning import create_version_snapshot @@ -46,38 +46,38 @@ from .base import ( HeartbeatCallbackType = Callable[[int], Awaitable[None]] -def _estimate_pages_safe(etl_credit_service: EtlCreditService, file_path: str) -> int: +def _estimate_pages_safe(page_limit_service: PageLimitService, file_path: str) -> int: """Estimate page count with a file-size fallback.""" try: - return etl_credit_service.estimate_pages_before_processing(file_path) + return page_limit_service.estimate_pages_before_processing(file_path) except Exception: file_size = os.path.getsize(file_path) return max(1, file_size // (80 * 1024)) -async def _check_credits_or_skip( - etl_credit_service: EtlCreditService, +async def _check_page_limit_or_skip( + page_limit_service: PageLimitService, user_id: str, file_path: str, page_multiplier: int = 1, ) -> tuple[int, int]: - """Estimate pages and check credit; raises InsufficientCreditsError if unaffordable. + """Estimate pages and check the limit; raises PageLimitExceededError if over quota. Returns (estimated_pages, billable_pages). """ - estimated = _estimate_pages_safe(etl_credit_service, file_path) + estimated = _estimate_pages_safe(page_limit_service, file_path) billable = estimated * page_multiplier - await etl_credit_service.check_credits(user_id, billable) + await page_limit_service.check_page_limit(user_id, billable) return estimated, billable def _compute_final_pages( - etl_credit_service: EtlCreditService, + page_limit_service: PageLimitService, estimated_pages: int, content_length: int, ) -> int: """Return the final page count as max(estimated, actual).""" - actual = etl_credit_service.estimate_pages_from_content_length(content_length) + actual = page_limit_service.estimate_pages_from_content_length(content_length) return max(estimated_pages, actual) @@ -635,7 +635,7 @@ async def index_local_folder( skipped_count = 0 failed_count = 0 - etl_credit_service = EtlCreditService(session) + page_limit_service = PageLimitService(session) # ================================================================ # PHASE 1: Pre-filter files (mtime / content-hash), version changed @@ -694,12 +694,12 @@ async def index_local_folder( continue try: - estimated_pages, _billable = await _check_credits_or_skip( - etl_credit_service, user_id, file_path_abs + estimated_pages, _billable = await _check_page_limit_or_skip( + page_limit_service, user_id, file_path_abs ) - except InsufficientCreditsError: + except PageLimitExceededError: logger.warning( - f"Insufficient credits, skipping: {file_path_abs}" + f"Page limit exceeded, skipping: {file_path_abs}" ) failed_count += 1 continue @@ -730,12 +730,12 @@ async def index_local_folder( await create_version_snapshot(session, existing_document) else: try: - estimated_pages, _billable = await _check_credits_or_skip( - etl_credit_service, user_id, file_path_abs + estimated_pages, _billable = await _check_page_limit_or_skip( + page_limit_service, user_id, file_path_abs ) - except InsufficientCreditsError: + except PageLimitExceededError: logger.warning( - f"Insufficient credits, skipping: {file_path_abs}" + f"Page limit exceeded, skipping: {file_path_abs}" ) failed_count += 1 continue @@ -858,9 +858,11 @@ async def index_local_folder( est = mtime_info.get("estimated_pages", 1) content_len = mtime_info.get("content_length", 0) final_pages = _compute_final_pages( - etl_credit_service, est, content_len + page_limit_service, est, content_len + ) + await page_limit_service.update_page_usage( + user_id, final_pages, allow_exceed=True ) - await etl_credit_service.charge_credits(user_id, final_pages) else: failed_count += 1 @@ -1070,13 +1072,13 @@ async def _index_single_file( await session.commit() return 0, 0, None - etl_credit_service = EtlCreditService(session) + page_limit_service = PageLimitService(session) try: - estimated_pages, _billable = await _check_credits_or_skip( - etl_credit_service, user_id, str(full_path) + estimated_pages, _billable = await _check_page_limit_or_skip( + page_limit_service, user_id, str(full_path) ) - except InsufficientCreditsError as e: - return 0, 1, f"Insufficient credits: {e}" + except PageLimitExceededError as e: + return 0, 1, f"Page limit exceeded: {e}" try: content, content_hash = await _compute_file_content_hash( @@ -1140,9 +1142,11 @@ async def _index_single_file( if indexed: final_pages = _compute_final_pages( - etl_credit_service, estimated_pages, len(content) + page_limit_service, estimated_pages, len(content) + ) + await page_limit_service.update_page_usage( + user_id, final_pages, allow_exceed=True ) - await etl_credit_service.charge_credits(user_id, final_pages) await task_logger.log_task_success( log_entry, f"Single file indexed: {rel_path}", @@ -1295,7 +1299,7 @@ async def index_uploaded_files( await _set_indexing_flag(session, root_folder_id) - etl_credit_service = EtlCreditService(session) + page_limit_service = PageLimitService(session) pipeline = IndexingPipelineService(session) vision_llm_instance = None @@ -1341,14 +1345,14 @@ async def index_uploaded_files( continue try: - estimated_pages, _billable_pages = await _check_credits_or_skip( - etl_credit_service, + estimated_pages, _billable_pages = await _check_page_limit_or_skip( + page_limit_service, user_id, temp_path, page_multiplier=mode.page_multiplier, ) - except InsufficientCreditsError: - logger.warning(f"Insufficient credits, skipping: {relative_path}") + except PageLimitExceededError: + logger.warning(f"Page limit exceeded, skipping: {relative_path}") failed_count += 1 continue @@ -1421,10 +1425,12 @@ async def index_uploaded_files( if DocumentStatus.is_state(db_doc.status, DocumentStatus.READY): indexed_count += 1 final_pages = _compute_final_pages( - etl_credit_service, estimated_pages, len(content) + page_limit_service, estimated_pages, len(content) ) final_billable = final_pages * mode.page_multiplier - await etl_credit_service.charge_credits(user_id, final_billable) + await page_limit_service.update_page_usage( + user_id, final_billable, allow_exceed=True + ) else: failed_count += 1 diff --git a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py index 1a83551fb..3fd8a79f2 100644 --- a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py @@ -28,7 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.exceptions import safe_exception_message from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -318,8 +318,9 @@ async def _index_selected_files( vision_llm=None, ) -> tuple[int, int, int, list[str]]: """Index user-selected files using the parallel pipeline.""" - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 files_to_download: list[dict] = [] @@ -345,16 +346,12 @@ async def _index_selected_files( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: display = file_name or file_id - errors.append(f"File '{display}': insufficient credits") + errors.append(f"File '{display}': page limit would be exceeded") continue batch_estimated_pages += file_pages @@ -375,7 +372,9 @@ async def _index_selected_files( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) return renamed_count + batch_indexed, skipped, unsupported_count, errors @@ -414,8 +413,9 @@ async def _index_full_scan( }, ) - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -448,17 +448,13 @@ async def _index_full_scan( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( file.get("name", ""), file.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during OneDrive full scan, " + "Page limit reached during OneDrive full scan, " "skipping remaining files" ) page_limit_reached = True @@ -483,7 +479,9 @@ async def _index_full_scan( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( @@ -534,8 +532,9 @@ async def _index_with_delta_sync( logger.info(f"Processing {len(changes)} delta changes") - etl_credit_service = EtlCreditService(session) - available_micros = await etl_credit_service.get_available_micros(user_id) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used batch_estimated_pages = 0 page_limit_reached = False @@ -572,17 +571,13 @@ async def _index_with_delta_sync( skipped += 1 continue - file_pages = EtlCreditService.estimate_pages_from_metadata( + file_pages = PageLimitService.estimate_pages_from_metadata( change.get("name", ""), change.get("size") ) - if ( - available_micros is not None - and EtlCreditService.pages_to_micros(batch_estimated_pages + file_pages) - > available_micros - ): + if batch_estimated_pages + file_pages > remaining_quota: if not page_limit_reached: logger.warning( - "Insufficient credits during OneDrive delta sync, " + "Page limit reached during OneDrive delta sync, " "skipping remaining files" ) page_limit_reached = True @@ -607,7 +602,9 @@ async def _index_with_delta_sync( pages_to_deduct = max( 1, batch_estimated_pages * batch_indexed // len(files_to_download) ) - await etl_credit_service.charge_credits(user_id, pages_to_deduct) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) indexed = renamed_count + batch_indexed logger.info( diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index a646b7aa6..f6929b87c 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -79,10 +79,10 @@ async def _notify( # --------------------------------------------------------------------------- -def _estimate_pages_safe(etl_credit_service, file_path: str) -> int: +def _estimate_pages_safe(page_limit_service, file_path: str) -> int: """Estimate page count with a file-size fallback.""" try: - return etl_credit_service.estimate_pages_before_processing(file_path) + return page_limit_service.estimate_pages_before_processing(file_path) except Exception: file_size = os.path.getsize(file_path) return max(1, file_size // (80 * 1024)) @@ -185,14 +185,11 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: """Route a document file to the configured ETL service via the unified pipeline.""" from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - from app.services.etl_credit_service import ( - EtlCreditService, - InsufficientCreditsError, - ) + from app.services.page_limit_service import PageLimitExceededError, PageLimitService mode = ProcessingMode.coerce(ctx.processing_mode) - etl_credit_service = EtlCreditService(ctx.session) - estimated_pages = _estimate_pages_safe(etl_credit_service, ctx.file_path) + page_limit_service = PageLimitService(ctx.session) + estimated_pages = _estimate_pages_safe(page_limit_service, ctx.file_path) billable_pages = estimated_pages * mode.page_multiplier await ctx.task_logger.log_task_progress( @@ -207,16 +204,16 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: ) try: - await etl_credit_service.check_credits(ctx.user_id, billable_pages) - except InsufficientCreditsError as e: + await page_limit_service.check_page_limit(ctx.user_id, billable_pages) + except PageLimitExceededError as e: await ctx.task_logger.log_task_failure( ctx.log_entry, - f"Insufficient credits before processing: {ctx.filename}", + f"Page limit exceeded before processing: {ctx.filename}", str(e), { - "error_type": "InsufficientCredits", - "balance_micros": e.balance_micros, - "required_micros": e.required_micros, + "error_type": "PageLimitExceeded", + "pages_used": e.pages_used, + "pages_limit": e.pages_limit, "estimated_pages": estimated_pages, "billable_pages": billable_pages, "processing_mode": mode.value, @@ -262,7 +259,9 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: ) if result: - await etl_credit_service.charge_credits(ctx.user_id, billable_pages) + await page_limit_service.update_page_usage( + ctx.user_id, billable_pages, allow_exceed=True + ) if ctx.connector: await update_document_from_connector(result, ctx.connector, ctx.session) await ctx.task_logger.log_task_success( @@ -338,11 +337,11 @@ async def process_file_in_background( except Exception as e: await session.rollback() - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - if isinstance(e, InsufficientCreditsError): + if isinstance(e, PageLimitExceededError): error_message = str(e) - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): error_message = str(e.detail) else: error_message = f"Failed to process file: {filename}" @@ -415,12 +414,12 @@ async def _extract_file_content( ) if category == FileCategory.DOCUMENT: - from app.services.etl_credit_service import EtlCreditService + from app.services.page_limit_service import PageLimitService - etl_credit_service = EtlCreditService(session) - estimated_pages = _estimate_pages_safe(etl_credit_service, file_path) + page_limit_service = PageLimitService(session) + estimated_pages = _estimate_pages_safe(page_limit_service, file_path) billable_pages = estimated_pages * mode.page_multiplier - await etl_credit_service.check_credits(user_id, billable_pages) + await page_limit_service.check_page_limit(user_id, billable_pages) # Vision LLM is provided to the ETL pipeline for any file category # when the operator opts in. Image files run through it directly; @@ -525,10 +524,12 @@ async def process_file_in_background_with_document( ) if billable_pages > 0: - from app.services.etl_credit_service import EtlCreditService + from app.services.page_limit_service import PageLimitService - etl_credit_service = EtlCreditService(session) - await etl_credit_service.charge_credits(user_id, billable_pages) + page_limit_service = PageLimitService(session) + await page_limit_service.update_page_usage( + user_id, billable_pages, allow_exceed=True + ) await task_logger.log_task_success( log_entry, @@ -546,11 +547,11 @@ async def process_file_in_background_with_document( except Exception as e: await session.rollback() - from app.services.etl_credit_service import InsufficientCreditsError + from app.services.page_limit_service import PageLimitExceededError - if isinstance(e, InsufficientCreditsError): + if isinstance(e, PageLimitExceededError): error_message = str(e) - elif isinstance(e, HTTPException) and "credit" in str(e.detail).lower(): + elif isinstance(e, HTTPException) and "page limit" in str(e.detail).lower(): error_message = str(e.detail) else: error_message = f"Failed to process file: {filename}" diff --git a/surfsense_backend/app/zero_publication.py b/surfsense_backend/app/zero_publication.py index b14ee14d1..d2755d0a1 100644 --- a/surfsense_backend/app/zero_publication.py +++ b/surfsense_backend/app/zero_publication.py @@ -38,7 +38,10 @@ DOCUMENT_COLS = [ USER_COLS = [ "id", - "credit_micros_balance", + "pages_limit", + "pages_used", + "premium_credit_micros_limit", + "premium_credit_micros_used", ] AUTOMATION_RUN_COLS = [ @@ -52,22 +55,6 @@ AUTOMATION_RUN_COLS = [ "created_at", ] -# Enough to drive the lifecycle UI by push: status, the reviewable brief, and -# its version. The bulky source_content and transcript are deliberately excluded -# and fetched over REST when a gate opens. -PODCAST_COLS = [ - "id", - "title", - "status", - "spec", - "spec_version", - "duration_seconds", - "error", - "search_space_id", - "thread_id", - "created_at", -] - ZERO_PUBLICATION: Mapping[str, Sequence[str] | None] = { "notifications": None, "documents": DOCUMENT_COLS, @@ -78,7 +65,6 @@ ZERO_PUBLICATION: Mapping[str, Sequence[str] | None] = { "chat_session_state": None, "user": USER_COLS, "automation_runs": AUTOMATION_RUN_COLS, - "podcasts": PODCAST_COLS, } @@ -86,15 +72,18 @@ def _quote_identifier(identifier: str) -> str: return '"' + identifier.replace('"', '""') + '"' -def _table_columns(conn: Connection, table: str) -> set[str]: - rows = conn.execute( - text( - "SELECT column_name FROM information_schema.columns " - "WHERE table_schema = current_schema() AND table_name = :table" - ), - {"table": table}, - ).fetchall() - return {row[0] for row in rows} +def _column_exists(conn: Connection, table: str, column: str) -> bool: + return ( + conn.execute( + text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_schema = current_schema() " + "AND table_name = :table AND column_name = :column" + ), + {"table": table, "column": column}, + ).fetchone() + is not None + ) def _expected_columns(conn: Connection, table: str) -> list[str] | None: @@ -103,39 +92,17 @@ def _expected_columns(conn: Connection, table: str) -> list[str] | None: return None expected = list(columns) - if table in {"documents", "user", "podcasts"} and "_0_version" in _table_columns( - conn, table - ): + if table in {"documents", "user"} and _column_exists(conn, table, "_0_version"): expected.append("_0_version") return expected -def _format_table_entry(conn: Connection, table: str) -> str | None: - """Render one SET TABLE entry, or ``None`` if the table isn't ready. - - Historical migrations (e.g. 155/156) call ``apply_publication`` while the - schema is still mid-history, before later migrations add columns that the - canonical shape references. A table is only published once it exists AND - every canonical column exists; otherwise it is omitted entirely and a later - reconcile migration (e.g. 159) picks it up once its columns land. Partial - column lists are deliberately avoided: publishing a column early would - block later ``ALTER COLUMN ... TYPE`` migrations on it (Postgres forbids - retyping columns a publication depends on). ``verify_publication`` remains - strict against the unfiltered canonical shape. - """ - - actual = _table_columns(conn, table) - if not actual: - return None - - table_sql = _quote_identifier(table) +def _format_table_entry(conn: Connection, table: str) -> str: columns = _expected_columns(conn, table) + table_sql = _quote_identifier(table) if columns is None: return table_sql - if any(column not in actual for column in columns): - return None - column_sql = ", ".join(_quote_identifier(column) for column in columns) return f"{table_sql} ({column_sql})" @@ -143,8 +110,9 @@ def _format_table_entry(conn: Connection, table: str) -> str | None: def build_set_table_sql(conn: Connection) -> str: """Build the canonical plain SET TABLE statement for Zero's event triggers.""" - entries = [_format_table_entry(conn, table) for table in ZERO_PUBLICATION] - table_list = ", ".join(entry for entry in entries if entry is not None) + table_list = ", ".join( + _format_table_entry(conn, table) for table in ZERO_PUBLICATION + ) return f"ALTER PUBLICATION {_quote_identifier(PUBLICATION_NAME)} SET TABLE {table_list}" diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index ff43f6a97..16d46445c 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "surf-new-backend" -version = "0.0.28" +version = "0.0.27" description = "SurfSense Backend" requires-python = ">=3.12" dependencies = [ diff --git a/surfsense_backend/tests/integration/document_upload/conftest.py b/surfsense_backend/tests/integration/document_upload/conftest.py index 812140be3..13e3ab59c 100644 --- a/surfsense_backend/tests/integration/document_upload/conftest.py +++ b/surfsense_backend/tests/integration/document_upload/conftest.py @@ -204,34 +204,32 @@ async def _cleanup_documents( # --------------------------------------------------------------------------- -# Credit-wallet helpers (direct DB for setup, API for verification) +# Page-limit helpers (direct DB for setup, API for verification) # --------------------------------------------------------------------------- -async def _get_user_credit(email: str) -> tuple[int, int]: +async def _get_user_page_usage(email: str) -> tuple[int, int]: conn = await asyncpg.connect(_ASYNCPG_URL) try: row = await conn.fetchrow( - "SELECT credit_micros_balance, credit_micros_reserved " - 'FROM "user" WHERE email = $1', + 'SELECT pages_used, pages_limit FROM "user" WHERE email = $1', email, ) assert row is not None, f"User {email!r} not found in database" - return row["credit_micros_balance"], row["credit_micros_reserved"] + return row["pages_used"], row["pages_limit"] finally: await conn.close() -async def _set_user_credit( - email: str, *, balance_micros: int, reserved_micros: int = 0 +async def _set_user_page_limits( + email: str, *, pages_used: int, pages_limit: int ) -> None: conn = await asyncpg.connect(_ASYNCPG_URL) try: await conn.execute( - 'UPDATE "user" SET credit_micros_balance = $1, ' - "credit_micros_reserved = $2 WHERE email = $3", - balance_micros, - reserved_micros, + 'UPDATE "user" SET pages_used = $1, pages_limit = $2 WHERE email = $3', + pages_used, + pages_limit, email, ) finally: @@ -239,39 +237,23 @@ async def _set_user_credit( @pytest.fixture -async def credits(): - """Manipulate the test user's credit wallet (direct DB for setup only). +async def page_limits(): + """Manipulate the test user's page limits (direct DB for setup only). - Force-enables ETL credit billing for the duration of the test (it is off - by default for self-hosted/OSS, which would bypass all gating), and - automatically restores the original balance and billing flag afterwards. - - ``MICROS_PER_PAGE`` is exposed so callers can size balances by page count. + Automatically restores original values after each test. """ - class _Credits: - micros_per_page = app_config.MICROS_PER_PAGE - - async def set(self, *, balance_micros: int, reserved_micros: int = 0) -> None: - await _set_user_credit( - TEST_EMAIL, - balance_micros=balance_micros, - reserved_micros=reserved_micros, + class _PageLimits: + async def set(self, *, pages_used: int, pages_limit: int) -> None: + await _set_user_page_limits( + TEST_EMAIL, pages_used=pages_used, pages_limit=pages_limit ) - def pages(self, n: int) -> int: - return n * app_config.MICROS_PER_PAGE - - original_billing = app_config.ETL_CREDIT_BILLING_ENABLED - app_config.ETL_CREDIT_BILLING_ENABLED = True - original = await _get_user_credit(TEST_EMAIL) - try: - yield _Credits() - finally: - app_config.ETL_CREDIT_BILLING_ENABLED = original_billing - await _set_user_credit( - TEST_EMAIL, balance_micros=original[0], reserved_micros=original[1] - ) + original = await _get_user_page_usage(TEST_EMAIL) + yield _PageLimits() + await _set_user_page_limits( + TEST_EMAIL, pages_used=original[0], pages_limit=original[1] + ) # --------------------------------------------------------------------------- diff --git a/surfsense_backend/tests/integration/document_upload/test_etl_credits.py b/surfsense_backend/tests/integration/document_upload/test_page_limits.py similarity index 67% rename from surfsense_backend/tests/integration/document_upload/test_etl_credits.py rename to surfsense_backend/tests/integration/document_upload/test_page_limits.py index 6a2972598..985fd7128 100644 --- a/surfsense_backend/tests/integration/document_upload/test_etl_credits.py +++ b/surfsense_backend/tests/integration/document_upload/test_page_limits.py @@ -1,14 +1,14 @@ """ -Integration tests for ETL credit enforcement during document upload. +Integration tests for page-limit enforcement during document upload. -These tests manipulate the test user's ``credit_micros_balance`` column -directly in the database (setup only) and then exercise the upload pipeline -to verify that: +These tests manipulate the test user's ``pages_used`` / ``pages_limit`` +columns directly in the database (setup only) and then exercise the upload +pipeline to verify that: - - Uploads are rejected *before* ETL when the wallet can't cover the cost. - - The balance decreases after a successful upload (verified via API). - - An ``insufficient_credits`` notification is created on rejection. - - The balance is not modified when a document fails processing. + - Uploads are rejected *before* ETL when the limit is exhausted. + - ``pages_used`` increases after a successful upload (verified via API). + - A ``page_limit_exceeded`` notification is created on rejection. + - ``pages_used`` is not modified when a document fails processing. All tests reuse the existing small fixtures (``sample.pdf``, ``sample.txt``) so no additional processing time is introduced. @@ -32,37 +32,36 @@ pytestmark = pytest.mark.integration # --------------------------------------------------------------------------- -# Helper: read credit balance through the public API +# Helper: read pages_used through the public API # --------------------------------------------------------------------------- -async def _get_balance(client: httpx.AsyncClient, headers: dict[str, str]) -> int: - """Fetch the current user's credit_micros_balance via the /users/me API.""" +async def _get_pages_used(client: httpx.AsyncClient, headers: dict[str, str]) -> int: + """Fetch the current user's pages_used via the /users/me API.""" resp = await client.get("/users/me", headers=headers) assert resp.status_code == 200, ( f"GET /users/me failed ({resp.status_code}): {resp.text}" ) - return resp.json()["credit_micros_balance"] + return resp.json()["pages_used"] # --------------------------------------------------------------------------- -# Test A: Successful upload decrements the balance +# Test A: Successful upload increments pages_used # --------------------------------------------------------------------------- -class TestBalanceDecrementsOnSuccess: - """After a successful PDF upload the user's balance must shrink.""" +class TestPageUsageIncrementsOnSuccess: + """After a successful PDF upload the user's ``pages_used`` must grow.""" - async def test_balance_decreases_after_pdf_upload( + async def test_pages_used_increases_after_pdf_upload( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=credits.pages(1000)) - before = await _get_balance(client, headers) + await page_limits.set(pages_used=0, pages_limit=1000) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -77,28 +76,30 @@ class TestBalanceDecrementsOnSuccess: for did in doc_ids: assert statuses[did]["status"]["state"] == "ready" - after = await _get_balance(client, headers) - assert after < before, "balance should have dropped after successful processing" + used = await _get_pages_used(client, headers) + assert used > 0, "pages_used should have increased after successful processing" # --------------------------------------------------------------------------- -# Test B: Upload rejected when the wallet is empty +# Test B: Upload rejected when page limit is fully exhausted # --------------------------------------------------------------------------- -class TestUploadRejectedWhenCreditExhausted: - """When the balance is zero the document should reach ``failed`` status - with an insufficient-credit reason.""" +class TestUploadRejectedWhenLimitExhausted: + """ + When ``pages_used == pages_limit`` (zero remaining) the document + should reach ``failed`` status with a page-limit reason. + """ - async def test_pdf_fails_when_no_credit_remaining( + async def test_pdf_fails_when_no_pages_remaining( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=0) + await page_limits.set(pages_used=100, pages_limit=100) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -113,19 +114,19 @@ class TestUploadRejectedWhenCreditExhausted: for did in doc_ids: assert statuses[did]["status"]["state"] == "failed" reason = statuses[did]["status"].get("reason", "").lower() - assert "credit" in reason, ( - f"Expected 'credit' in failure reason, got: {reason!r}" + assert "page limit" in reason, ( + f"Expected 'page limit' in failure reason, got: {reason!r}" ) - async def test_balance_unchanged_after_rejection( + async def test_pages_used_unchanged_after_limit_rejection( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=0) + await page_limits.set(pages_used=50, pages_limit=50) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -138,30 +139,30 @@ class TestUploadRejectedWhenCreditExhausted: client, headers, doc_ids, search_space_id=search_space_id, timeout=300.0 ) - balance = await _get_balance(client, headers) - assert balance == 0, ( - f"balance should remain 0 after rejected upload, got {balance}" + used = await _get_pages_used(client, headers) + assert used == 50, ( + f"pages_used should remain 50 after rejected upload, got {used}" ) # --------------------------------------------------------------------------- -# Test C: Insufficient-credits notification is created on rejection +# Test C: Page-limit notification is created on rejection # --------------------------------------------------------------------------- -class TestInsufficientCreditsNotification: - """An ``insufficient_credits`` notification must be created when upload - is rejected due to an empty wallet.""" +class TestPageLimitNotification: + """A ``page_limit_exceeded`` notification must be created when upload + is rejected due to the limit.""" - async def test_insufficient_credits_notification_created( + async def test_page_limit_exceeded_notification_created( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=0) + await page_limits.set(pages_used=100, pages_limit=100) resp = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -177,18 +178,19 @@ class TestInsufficientCreditsNotification: notifications = await get_notifications( client, headers, - type_filter="insufficient_credits", + type_filter="page_limit_exceeded", search_space_id=search_space_id, ) assert len(notifications) >= 1, ( - "Expected at least one insufficient_credits notification" + "Expected at least one page_limit_exceeded notification" ) latest = notifications[0] assert ( - "credit" in latest["title"].lower() or "credit" in latest["message"].lower() + "page limit" in latest["title"].lower() + or "page limit" in latest["message"].lower() ), ( - f"Notification should mention credit: title={latest['title']!r}, " + f"Notification should mention page limit: title={latest['title']!r}, " f"message={latest['message']!r}" ) @@ -208,9 +210,9 @@ class TestDocumentProcessingNotification: headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - await credits.set(balance_micros=credits.pages(1000)) + await page_limits.set(pages_used=0, pages_limit=1000) resp = await upload_file( client, headers, "sample.txt", search_space_id=search_space_id @@ -240,24 +242,23 @@ class TestDocumentProcessingNotification: # --------------------------------------------------------------------------- -# Test E: balance unchanged when a document fails for non-credit reasons +# Test E: pages_used unchanged when a document fails for non-limit reasons # --------------------------------------------------------------------------- -class TestBalanceUnchangedOnProcessingFailure: - """If a document fails during ETL (e.g. empty/corrupt file) rather than a - credit rejection, the balance should remain unchanged.""" +class TestPagesUnchangedOnProcessingFailure: + """If a document fails during ETL (e.g. empty/corrupt file) rather than + a page-limit rejection, ``pages_used`` should remain unchanged.""" - async def test_balance_stable_on_etl_failure( + async def test_pages_used_stable_on_etl_failure( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - starting = credits.pages(1000) - await credits.set(balance_micros=starting) + await page_limits.set(pages_used=10, pages_limit=1000) resp = await upload_file( client, headers, "empty.pdf", search_space_id=search_space_id @@ -273,32 +274,28 @@ class TestBalanceUnchangedOnProcessingFailure: for did in doc_ids: assert statuses[did]["status"]["state"] == "failed" - balance = await _get_balance(client, headers) - assert balance == starting, ( - f"balance should remain {starting} after ETL failure, got {balance}" - ) + used = await _get_pages_used(client, headers) + assert used == 10, f"pages_used should remain 10 after ETL failure, got {used}" # --------------------------------------------------------------------------- -# Test F: Second upload rejected after first consumes remaining credit +# Test F: Second upload rejected after first consumes remaining quota # --------------------------------------------------------------------------- -class TestSecondUploadExceedsCredit: - """Upload one PDF successfully, consuming the credit, then verify a second - upload is rejected.""" +class TestSecondUploadExceedsLimit: + """Upload one PDF successfully, consuming the quota, then verify a + second upload is rejected.""" - async def test_second_upload_rejected_after_credit_consumed( + async def test_second_upload_rejected_after_quota_consumed( self, client: httpx.AsyncClient, headers: dict[str, str], search_space_id: int, cleanup_doc_ids: list[int], - credits, + page_limits, ): - # Exactly one page of credit: the first 1-page PDF fits, the second - # is rejected once the wallet hits zero. - await credits.set(balance_micros=credits.pages(1)) + await page_limits.set(pages_used=0, pages_limit=1) resp1 = await upload_file( client, headers, "sample.pdf", search_space_id=search_space_id @@ -330,6 +327,6 @@ class TestSecondUploadExceedsCredit: for did in second_ids: assert statuses2[did]["status"]["state"] == "failed" reason = statuses2[did]["status"].get("reason", "").lower() - assert "credit" in reason, ( - f"Expected 'credit' in failure reason, got: {reason!r}" + assert "page limit" in reason, ( + f"Expected 'page limit' in failure reason, got: {reason!r}" ) diff --git a/surfsense_backend/tests/integration/document_upload/test_stripe_credit_purchases.py b/surfsense_backend/tests/integration/document_upload/test_stripe_page_purchases.py similarity index 76% rename from surfsense_backend/tests/integration/document_upload/test_stripe_credit_purchases.py rename to surfsense_backend/tests/integration/document_upload/test_stripe_page_purchases.py index e1955494d..143c9e252 100644 --- a/surfsense_backend/tests/integration/document_upload/test_stripe_credit_purchases.py +++ b/surfsense_backend/tests/integration/document_upload/test_stripe_page_purchases.py @@ -1,10 +1,3 @@ -"""Integration tests for Stripe credit-pack purchases. - -Buying credit packs tops up ``user.credit_micros_balance``. Legacy page-pack -buying has been removed; these tests exercise the credit checkout session, -webhook fulfillment (idempotent), and the reconciliation fallback. -""" - from __future__ import annotations from types import SimpleNamespace @@ -26,8 +19,6 @@ pytestmark = pytest.mark.integration _ASYNCPG_URL = TEST_DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://") -_CREDIT_MICROS_PER_UNIT = 1_000_000 - async def _execute(query: str, *args) -> None: conn = await asyncpg.connect(_ASYNCPG_URL) @@ -51,12 +42,10 @@ async def _get_user_id(email: str) -> str: return str(row["id"]) -async def _get_balance(email: str) -> int: - row = await _fetchrow( - 'SELECT credit_micros_balance FROM "user" WHERE email = $1', email - ) +async def _get_pages_limit(email: str) -> int: + row = await _fetchrow('SELECT pages_limit FROM "user" WHERE email = $1', email) assert row is not None, f"User {email!r} not found" - return row["credit_micros_balance"] + return row["pages_limit"] def _extract_access_token(response: httpx.Response) -> str | None: @@ -112,23 +101,10 @@ def headers(auth_token: str) -> dict[str, str]: @pytest.fixture(autouse=True) -async def _cleanup_credit_purchases(): - await _execute("DELETE FROM credit_purchases") +async def _cleanup_page_purchases(): + await _execute("DELETE FROM page_purchases") yield - await _execute("DELETE FROM credit_purchases") - - -def _configure_credit_buying(monkeypatch) -> None: - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", True) - monkeypatch.setattr( - stripe_routes.config, "STRIPE_CREDIT_PRICE_ID", "price_credit_1" - ) - monkeypatch.setattr( - stripe_routes.config, "STRIPE_CREDIT_MICROS_PER_UNIT", _CREDIT_MICROS_PER_UNIT - ) - monkeypatch.setattr( - stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" - ) + await _execute("DELETE FROM page_purchases") class _FakeCreateStripeClient: @@ -176,19 +152,18 @@ class _FakeReconciliationStripeClient: class TestStripeCheckoutSessionCreation: - async def test_credit_status_reflects_backend_toggle( + async def test_get_status_reflects_backend_toggle( self, client, headers, monkeypatch ): - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", False) - disabled = await client.get("/api/v1/stripe/credit-status", headers=headers) - assert disabled.status_code == 200, disabled.text - assert disabled.json()["credit_buying_enabled"] is False - assert "credit_micros_balance" in disabled.json() + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False) + disabled_response = await client.get("/api/v1/stripe/status", headers=headers) + assert disabled_response.status_code == 200, disabled_response.text + assert disabled_response.json() == {"page_buying_enabled": False} - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", True) - enabled = await client.get("/api/v1/stripe/credit-status", headers=headers) - assert enabled.status_code == 200, enabled.text - assert enabled.json()["credit_buying_enabled"] is True + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", True) + enabled_response = await client.get("/api/v1/stripe/status", headers=headers) + assert enabled_response.status_code == 200, enabled_response.text + assert enabled_response.json() == {"page_buying_enabled": True} async def test_create_checkout_session_records_pending_purchase( self, @@ -207,10 +182,14 @@ class TestStripeCheckoutSessionCreation: fake_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: fake_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 2, "search_space_id": search_space_id}, ) @@ -220,7 +199,7 @@ class TestStripeCheckoutSessionCreation: assert fake_client.last_params is not None assert fake_client.last_params["mode"] == "payment" assert fake_client.last_params["line_items"] == [ - {"price": "price_credit_1", "quantity": 2} + {"price": "price_pages_1000", "quantity": 2} ] assert ( fake_client.last_params["success_url"] @@ -231,21 +210,19 @@ class TestStripeCheckoutSessionCreation: fake_client.last_params["cancel_url"] == f"http://localhost:3000/dashboard/{search_space_id}/purchase-cancel" ) - assert fake_client.last_params["metadata"]["purchase_type"] == "credits" purchase = await _fetchrow( """ - SELECT quantity, credit_micros_granted, status, source - FROM credit_purchases + SELECT quantity, pages_granted, status + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, ) assert purchase is not None assert purchase["quantity"] == 2 - assert purchase["credit_micros_granted"] == 2 * _CREDIT_MICROS_PER_UNIT + assert purchase["pages_granted"] == 2000 assert purchase["status"] == "PENDING" - assert purchase["source"] == "checkout" async def test_create_checkout_session_returns_503_when_buying_disabled( self, @@ -254,34 +231,34 @@ class TestStripeCheckoutSessionCreation: search_space_id: int, monkeypatch, ): - monkeypatch.setattr(stripe_routes.config, "STRIPE_CREDIT_BUYING_ENABLED", False) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGE_BUYING_ENABLED", False) response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 2, "search_space_id": search_space_id}, ) assert response.status_code == 503, response.text assert ( - response.json()["detail"] == "Credit purchases are temporarily unavailable." + response.json()["detail"] == "Page purchases are temporarily unavailable." ) - count = await _fetchrow("SELECT COUNT(*) AS count FROM credit_purchases") - assert count is not None - assert count["count"] == 0 + purchase_count = await _fetchrow("SELECT COUNT(*) AS count FROM page_purchases") + assert purchase_count is not None + assert purchase_count["count"] == 0 class TestStripeWebhookFulfillment: - async def test_webhook_grants_credit_once( + async def test_webhook_grants_pages_once( self, client, headers, search_space_id: int, - credits, + page_limits, monkeypatch, ): - await credits.set(balance_micros=5_000_000) + await page_limits.set(pages_used=0, pages_limit=100) checkout_session = SimpleNamespace( id="cs_test_webhook_123", @@ -293,16 +270,21 @@ class TestStripeWebhookFulfillment: create_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) create_response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 3, "search_space_id": search_space_id}, ) assert create_response.status_code == 200, create_response.text - assert await _get_balance(TEST_EMAIL) == 5_000_000 + initial_limit = await _get_pages_limit(TEST_EMAIL) + assert initial_limit == 100 user_id = await _get_user_id(TEST_EMAIL) webhook_checkout_session = SimpleNamespace( @@ -314,8 +296,7 @@ class TestStripeWebhookFulfillment: metadata={ "user_id": user_id, "quantity": "3", - "credit_micros_per_unit": str(_CREDIT_MICROS_PER_UNIT), - "purchase_type": "credits", + "pages_per_unit": "1000", }, ) event = SimpleNamespace( @@ -334,12 +315,13 @@ class TestStripeWebhookFulfillment: ) assert first_response.status_code == 200, first_response.text - assert await _get_balance(TEST_EMAIL) == 5_000_000 + 3 * _CREDIT_MICROS_PER_UNIT + updated_limit = await _get_pages_limit(TEST_EMAIL) + assert updated_limit == 3100 purchase = await _fetchrow( """ SELECT status, amount_total, currency, stripe_payment_intent_id - FROM credit_purchases + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, @@ -357,8 +339,7 @@ class TestStripeWebhookFulfillment: ) assert second_response.status_code == 200, second_response.text - # Idempotent: a duplicate webhook does not double-grant. - assert await _get_balance(TEST_EMAIL) == 5_000_000 + 3 * _CREDIT_MICROS_PER_UNIT + assert await _get_pages_limit(TEST_EMAIL) == 3100 class TestStripeReconciliation: @@ -367,10 +348,10 @@ class TestStripeReconciliation: client, headers, search_space_id: int, - credits, + page_limits, monkeypatch, ): - await credits.set(balance_micros=1_000_000) + await page_limits.set(pages_used=220, pages_limit=150) checkout_session = SimpleNamespace( id="cs_test_reconcile_paid_123", @@ -382,15 +363,19 @@ class TestStripeReconciliation: create_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) create_response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 3, "search_space_id": search_space_id}, ) assert create_response.status_code == 200, create_response.text - assert await _get_balance(TEST_EMAIL) == 1_000_000 + assert await _get_pages_limit(TEST_EMAIL) == 150 reconciled_session = SimpleNamespace( id=checkout_session.id, @@ -417,15 +402,15 @@ class TestStripeReconciliation: 20, ) - await stripe_reconciliation_task._reconcile_pending_credit_purchases() + await stripe_reconciliation_task._reconcile_pending_page_purchases() assert reconcile_client.requested_ids == [checkout_session.id] - assert await _get_balance(TEST_EMAIL) == 1_000_000 + 3 * _CREDIT_MICROS_PER_UNIT + assert await _get_pages_limit(TEST_EMAIL) == 3220 purchase = await _fetchrow( """ SELECT status, amount_total, currency, stripe_payment_intent_id - FROM credit_purchases + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, @@ -441,10 +426,10 @@ class TestStripeReconciliation: client, headers, search_space_id: int, - credits, + page_limits, monkeypatch, ): - await credits.set(balance_micros=500_000) + await page_limits.set(pages_used=0, pages_limit=500) checkout_session = SimpleNamespace( id="cs_test_reconcile_expired_123", @@ -456,10 +441,14 @@ class TestStripeReconciliation: create_client = _FakeCreateStripeClient(checkout_session) monkeypatch.setattr(stripe_routes, "get_stripe_client", lambda: create_client) - _configure_credit_buying(monkeypatch) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PRICE_ID", "price_pages_1000") + monkeypatch.setattr( + stripe_routes.config, "NEXT_FRONTEND_URL", "http://localhost:3000" + ) + monkeypatch.setattr(stripe_routes.config, "STRIPE_PAGES_PER_UNIT", 1000) create_response = await client.post( - "/api/v1/stripe/create-credit-checkout-session", + "/api/v1/stripe/create-checkout-session", headers=headers, json={"quantity": 1, "search_space_id": search_space_id}, ) @@ -490,14 +479,14 @@ class TestStripeReconciliation: 20, ) - await stripe_reconciliation_task._reconcile_pending_credit_purchases() + await stripe_reconciliation_task._reconcile_pending_page_purchases() - assert await _get_balance(TEST_EMAIL) == 500_000 + assert await _get_pages_limit(TEST_EMAIL) == 500 purchase = await _fetchrow( """ SELECT status - FROM credit_purchases + FROM page_purchases WHERE stripe_checkout_session_id = $1 """, checkout_session.id, diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py index e37c34388..2cd378343 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py @@ -961,37 +961,24 @@ class TestDirectConvert: # ==================================================================== -# Tier 8: ETL Credits (CR1-CR6) +# Tier 8: Page Limits (PL1-PL6) # ==================================================================== -class TestEtlCredits: - @pytest.fixture(autouse=True) - def _enable_etl_billing(self, monkeypatch): - """Force ETL credit billing on (off by default for self-hosted/OSS).""" - from app.config import config - - monkeypatch.setattr(config, "ETL_CREDIT_BILLING_ENABLED", True) - - @staticmethod - def _micros(pages: int) -> int: - from app.config import config - - return pages * config.MICROS_PER_PAGE - +class TestPageLimits: @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr1_full_scan_debits_balance( + async def test_pl1_full_scan_increments_pages_used( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR1: Successful full-scan sync debits user.credit_micros_balance.""" + """PL1: Successful full-scan sync increments user.pages_used.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - starting = self._micros(500) - db_user.credit_micros_balance = starting + db_user.pages_used = 0 + db_user.pages_limit = 500 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1008,22 +995,21 @@ class TestEtlCredits: assert count == 1 await db_session.refresh(db_user) - assert db_user.credit_micros_balance < starting, ( - "balance should drop after indexing" - ) + assert db_user.pages_used > 0, "pages_used should increase after indexing" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr2_full_scan_blocked_when_credit_exhausted( + async def test_pl2_full_scan_blocked_when_limit_exhausted( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR2: Full-scan skips file when the wallet is empty.""" + """PL2: Full-scan skips file when page limit is exhausted.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - db_user.credit_micros_balance = 0 + db_user.pages_used = 100 + db_user.pages_limit = 100 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1039,23 +1025,21 @@ class TestEtlCredits: assert count == 0 await db_session.refresh(db_user) - assert db_user.credit_micros_balance == 0, ( - "balance should not change on rejection" - ) + assert db_user.pages_used == 100, "pages_used should not change on rejection" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr3_single_file_debits_balance( + async def test_pl3_single_file_increments_pages_used( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR3: Single-file mode debits balance on success.""" + """PL3: Single-file mode increments user.pages_used on success.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - starting = self._micros(500) - db_user.credit_micros_balance = starting + db_user.pages_used = 0 + db_user.pages_limit = 500 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1073,22 +1057,21 @@ class TestEtlCredits: assert count == 1 await db_session.refresh(db_user) - assert db_user.credit_micros_balance < starting, ( - "balance should drop after indexing" - ) + assert db_user.pages_used > 0, "pages_used should increase after indexing" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr4_single_file_blocked_when_credit_exhausted( + async def test_pl4_single_file_blocked_when_limit_exhausted( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR4: Single-file mode skips file when the wallet is empty.""" + """PL4: Single-file mode skips file when page limit is exhausted.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - db_user.credit_micros_balance = 0 + db_user.pages_used = 100 + db_user.pages_limit = 100 await db_session.flush() (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") @@ -1104,25 +1087,24 @@ class TestEtlCredits: assert count == 0 assert err is not None - assert "credit" in err.lower() + assert "page limit" in err.lower() await db_session.refresh(db_user) - assert db_user.credit_micros_balance == 0, ( - "balance should not change on rejection" - ) + assert db_user.pages_used == 100, "pages_used should not change on rejection" @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr5_unchanged_resync_no_extra_debit( + async def test_pl5_unchanged_resync_no_extra_pages( self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace, tmp_path: Path, ): - """CR5: Re-syncing an unchanged file does not consume additional credit.""" + """PL5: Re-syncing an unchanged file does not consume additional pages.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - db_user.credit_micros_balance = self._micros(500) + db_user.pages_used = 0 + db_user.pages_limit = 500 await db_session.flush() (tmp_path / "note.md").write_text("# Hello\n\nSame content.") @@ -1137,8 +1119,8 @@ class TestEtlCredits: assert count1 == 1 await db_session.refresh(db_user) - balance_after_first = db_user.credit_micros_balance - assert balance_after_first < self._micros(500) + pages_after_first = db_user.pages_used + assert pages_after_first > 0 count2, _, _, _ = await index_local_folder( session=db_session, @@ -1151,12 +1133,12 @@ class TestEtlCredits: assert count2 == 0 await db_session.refresh(db_user) - assert db_user.credit_micros_balance == balance_after_first, ( - "balance should not change for unchanged files" + assert db_user.pages_used == pages_after_first, ( + "pages_used should not increase for unchanged files" ) @pytest.mark.usefixtures(*UNIFIED_FIXTURES) - async def test_cr6_batch_partial_credit_exhaustion( + async def test_pl6_batch_partial_page_limit_exhaustion( self, db_session: AsyncSession, db_user: User, @@ -1164,11 +1146,11 @@ class TestEtlCredits: tmp_path: Path, patched_batch_sessions, ): - """CR6: Batch mode with a tiny balance: some files succeed, rest fail.""" + """PL6: Batch mode with a very low page limit: some files succeed, rest fail.""" from app.tasks.connector_indexers.local_folder_indexer import index_local_folder - # Exactly one page of credit. - db_user.credit_micros_balance = self._micros(1) + db_user.pages_used = 0 + db_user.pages_limit = 1 await db_session.flush() (tmp_path / "a.md").write_text("File A content") @@ -1189,13 +1171,12 @@ class TestEtlCredits: ) assert count >= 1, "at least one file should succeed" - assert failed >= 1, "at least one file should fail due to insufficient credits" + assert failed >= 1, "at least one file should fail due to page limit" assert count + failed == 3 await db_session.refresh(db_user) - # The wallet was drained by the successful file(s); it may dip slightly - # negative when the actual page count exceeds the pre-check estimate. - assert db_user.credit_micros_balance <= 0 + assert db_user.pages_used > 0 + assert db_user.pages_used <= db_user.pages_limit + 1 # ==================================================================== diff --git a/surfsense_backend/tests/integration/notifications/test_insufficient_credits_handler.py b/surfsense_backend/tests/integration/notifications/test_page_limit_handler.py similarity index 52% rename from surfsense_backend/tests/integration/notifications/test_insufficient_credits_handler.py rename to surfsense_backend/tests/integration/notifications/test_page_limit_handler.py index bdfa1b30c..ab89d63c9 100644 --- a/surfsense_backend/tests/integration/notifications/test_insufficient_credits_handler.py +++ b/surfsense_backend/tests/integration/notifications/test_page_limit_handler.py @@ -1,4 +1,4 @@ -"""Behavior guard for the insufficient-credits notification handler.""" +"""Behavior guard for the page-limit notification handler.""" from __future__ import annotations @@ -10,50 +10,52 @@ from app.notifications.service import NotificationService pytestmark = pytest.mark.integration -handler = NotificationService.insufficient_credits +handler = NotificationService.page_limit -async def test_insufficient_credits_message_and_action( +async def test_page_limit_message_and_action( db_session: AsyncSession, db_user: User, db_search_space: SearchSpace ): - """An insufficient-credits notification states cost and carries a buy-credits link.""" - notification = await handler.notify_insufficient_credits( + """A page-limit notification states usage and carries an upgrade action link.""" + notification = await handler.notify_page_limit_exceeded( session=db_session, user_id=db_user.id, document_name="short.pdf", document_type="FILE", search_space_id=db_search_space.id, - balance_micros=250_000, - required_micros=1_000_000, + pages_used=95, + pages_limit=100, + pages_to_add=10, ) - assert notification.type == "insufficient_credits" - assert notification.title == "Insufficient credits: short.pdf" + assert notification.type == "page_limit_exceeded" + assert notification.title == "Page limit exceeded: short.pdf" assert notification.message == ( - "This document costs about $1.00 to process but you have " - "$0.25 of credit left. Add more credits to continue." + "This document has ~10 page(s) but you've used 95/100 pages. " + "Upgrade to process more documents." ) assert notification.notification_metadata["status"] == "failed" - assert notification.notification_metadata["action_label"] == "Buy credits" + assert notification.notification_metadata["action_label"] == "Upgrade Plan" assert notification.notification_metadata["action_url"] == ( - f"/dashboard/{db_search_space.id}/buy-more" + f"/dashboard/{db_search_space.id}/more-pages" ) -async def test_insufficient_credits_truncates_long_name( +async def test_page_limit_truncates_long_name( db_session: AsyncSession, db_user: User, db_search_space: SearchSpace ): """A long document name is truncated in the notification title.""" long_name = "a" * 50 - notification = await handler.notify_insufficient_credits( + notification = await handler.notify_page_limit_exceeded( session=db_session, user_id=db_user.id, document_name=long_name, document_type="FILE", search_space_id=db_search_space.id, - balance_micros=250_000, - required_micros=1_000_000, + pages_used=95, + pages_limit=100, + pages_to_add=10, ) - assert notification.title == f"Insufficient credits: {'a' * 40}..." + assert notification.title == f"Page limit exceeded: {'a' * 40}..." diff --git a/surfsense_backend/tests/integration/podcasts/conftest.py b/surfsense_backend/tests/integration/podcasts/conftest.py deleted file mode 100644 index f244c17d2..000000000 --- a/surfsense_backend/tests/integration/podcasts/conftest.py +++ /dev/null @@ -1,321 +0,0 @@ -"""Podcast API + task integration fixtures. - -The app's DB session and current-user dependencies ride the test's transactional -`db_session`, so seeded rows and rows touched through the endpoints (or the task -bodies) share one transaction that rolls back per test. Only true externals are -faked: the Celery broker (`*_task.delay`) is captured instead of dispatched, the -object store is a tiny in-memory backend, the Celery tasks' own session maker is -bound to the test transaction, and — for the render task — the TTS provider and -the FFmpeg merge are stubbed. `TTS_SERVICE` is pinned so the deterministic brief -proposal can resolve voices. -""" - -from __future__ import annotations - -import contextlib -import uuid -from collections.abc import AsyncGenerator, AsyncIterator -from pathlib import Path - -import httpx -import pytest -import pytest_asyncio -from httpx import ASGITransport -from sqlalchemy.ext.asyncio import AsyncSession - -from app.app import app, limiter -from app.config import config as app_config -from app.db import SearchSpace, User, get_async_session -from app.podcasts.persistence import Podcast, PodcastStatus -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, -) -from app.podcasts.service import PodcastService -from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech -from app.routes.search_spaces_routes import create_default_roles_and_membership -from app.users import current_active_user - -pytestmark = pytest.mark.integration - -limiter.enabled = False - - -@pytest_asyncio.fixture -async def client( - db_session: AsyncSession, - db_user: User, -) -> AsyncGenerator[httpx.AsyncClient, None]: - async def override_session() -> AsyncGenerator[AsyncSession, None]: - yield db_session - - async def override_user() -> User: - return db_user - - previous_overrides = app.dependency_overrides.copy() - app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user - - try: - async with httpx.AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - timeout=30.0, - follow_redirects=False, - ) as test_client: - yield test_client - finally: - app.dependency_overrides.clear() - app.dependency_overrides.update(previous_overrides) - - -@pytest.fixture(autouse=True) -def tts_service(monkeypatch) -> str: - """Pin a provider with language-agnostic voices so brief proposal resolves.""" - service = "openai/tts-1" - monkeypatch.setattr(app_config, "TTS_SERVICE", service) - return service - - -class CapturedTasks: - """Records the args each podcast Celery task was enqueued with.""" - - def __init__(self) -> None: - self.draft: list[tuple] = [] - self.render: list[tuple] = [] - - -@pytest.fixture(autouse=True) -def captured_tasks(monkeypatch) -> CapturedTasks: - """Capture `*_task.delay` instead of hitting the broker (a boundary).""" - captured = CapturedTasks() - from app.podcasts.tasks import draft_transcript_task, render_audio_task - - monkeypatch.setattr( - draft_transcript_task, "delay", lambda *a, **k: captured.draft.append((a, k)) - ) - monkeypatch.setattr( - render_audio_task, "delay", lambda *a, **k: captured.render.append((a, k)) - ) - return captured - - -class FakeStorageBackend: - """In-memory object store standing in for the real audio backend.""" - - backend_name = "memory" - - def __init__(self) -> None: - self.objects: dict[str, bytes] = {} - self.deleted: list[str] = [] - - async def put(self, key: str, data: bytes, content_type: str | None = None) -> None: - self.objects[key] = data - - async def open_stream(self, key: str) -> AsyncIterator[bytes]: - yield self.objects.get(key, b"audio-bytes") - - async def delete(self, key: str) -> None: - self.deleted.append(key) - - -@pytest.fixture -def fake_storage(monkeypatch) -> FakeStorageBackend: - """Route audio storage to an in-memory backend for the stream routes.""" - backend = FakeStorageBackend() - monkeypatch.setattr("app.podcasts.storage.get_storage_backend", lambda: backend) - monkeypatch.setattr("app.file_storage.factory.get_storage_backend", lambda: backend) - return backend - - -@pytest.fixture -def bind_task_session(db_session: AsyncSession, monkeypatch) -> AsyncSession: - """Bind the Celery tasks' own session maker to the test transaction. - - Task bodies open ``get_celery_session_maker()()`` rather than receiving a - session, so this hands them the test's session without closing it on exit; a - task's ``commit()`` then releases a savepoint and the per-test rollback still - cleans up. - """ - - def _make_session(): - @contextlib.asynccontextmanager - async def _ctx() -> AsyncIterator[AsyncSession]: - yield db_session - - return _ctx() - - for module in ( - "app.podcasts.tasks.draft", - "app.podcasts.tasks.render", - "app.podcasts.tasks.runtime", - ): - monkeypatch.setattr(f"{module}.get_celery_session_maker", lambda: _make_session) - return db_session - - -class FakeTextToSpeech(TextToSpeech): - """In-memory TTS provider: every segment yields fixed bytes (the boundary). - - Records each request so tests can assert how often synthesis was paid for. - """ - - def __init__(self) -> None: - self.requests: list[SynthesisRequest] = [] - - @property - def container(self) -> str: - return "mp3" - - async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio: - self.requests.append(request) - return SynthesizedAudio(data=b"segment-audio", container="mp3") - - -@pytest.fixture -def fake_tts(monkeypatch) -> FakeTextToSpeech: - """Stand in for the configured TTS provider in the render task.""" - provider = FakeTextToSpeech() - monkeypatch.setattr( - "app.podcasts.tasks.render.get_text_to_speech", lambda: provider - ) - return provider - - -@pytest.fixture -def fake_merge(monkeypatch) -> None: - """Stub the FFmpeg merge (an external binary) to emit a fixed MP3.""" - - async def _merge(segment_paths: list[Path], output_path: Path) -> None: - output_path.write_bytes(b"merged-audio") - - monkeypatch.setattr("app.podcasts.rendering.renderer.concat_to_mp3", _merge) - - -def build_spec( - *, - language: str = "en", - voice_ids: tuple[str, str] = ("openai:alloy", "openai:nova"), -) -> PodcastSpec: - """A valid two-speaker brief; tests override only what they assert on.""" - return PodcastSpec( - language=language, - style=PodcastStyle.CONVERSATIONAL, - speakers=[ - SpeakerSpec( - slot=0, name="Host", role=SpeakerRole.HOST, voice_id=voice_ids[0] - ), - SpeakerSpec( - slot=1, name="Guest", role=SpeakerRole.GUEST, voice_id=voice_ids[1] - ), - ], - duration=DurationTarget(min_minutes=10, max_minutes=20), - ) - - -def build_transcript() -> Transcript: - return Transcript( - turns=[ - TranscriptTurn(speaker=0, text="Welcome to the show."), - TranscriptTurn(speaker=1, text="Glad to be here."), - ] - ) - - -@pytest.fixture -def make_podcast(db_session: AsyncSession): - """Create a podcast advanced to a target lifecycle state via the service. - - Setup runs through the same public service the API uses, on the test's - session, so the endpoint under test reads a realistically-built row. - """ - - ladder = [ - PodcastStatus.AWAITING_BRIEF, - PodcastStatus.DRAFTING, - PodcastStatus.RENDERING, - PodcastStatus.READY, - ] - - async def _make( - *, - search_space_id: int, - status: PodcastStatus = PodcastStatus.AWAITING_BRIEF, - title: str = "Test Podcast", - thread_id: int | None = None, - ) -> Podcast: - service = PodcastService(db_session) - podcast = await service.create( - title=title, search_space_id=search_space_id, thread_id=thread_id - ) - if status is PodcastStatus.PENDING: - await db_session.flush() - return podcast - - targets = ladder[: ladder.index(status) + 1] - for target in targets: - if target is PodcastStatus.AWAITING_BRIEF: - await service.attach_brief(podcast, build_spec()) - elif target is PodcastStatus.DRAFTING: - await service.begin_drafting(podcast) - elif target is PodcastStatus.RENDERING: - await service.attach_transcript(podcast, build_transcript()) - elif target is PodcastStatus.READY: - await service.attach_audio( - podcast, - storage_backend="memory", - storage_key="podcasts/audio.mp3", - duration_seconds=123, - ) - await db_session.flush() - return podcast - - return _make - - -@pytest.fixture -def act_as(): - """Switch the authenticated user for subsequent requests on ``client``. - - The ``client`` fixture installs db_user and restores the prior overrides on - teardown, so re-pointing the auth dependency here is undone per test. - """ - - def _act(user: User) -> None: - app.dependency_overrides[current_active_user] = lambda: user - - return _act - - -@pytest_asyncio.fixture -async def db_other_user(db_session: AsyncSession) -> User: - """A second user who is not a member of ``db_search_space``.""" - user = User( - id=uuid.uuid4(), - email="stranger@surfsense.net", - hashed_password="hashed", - is_active=True, - is_superuser=False, - is_verified=True, - ) - db_session.add(user) - await db_session.flush() - return user - - -@pytest_asyncio.fixture -async def foreign_podcast( - db_session: AsyncSession, db_other_user: User, make_podcast -) -> Podcast: - """A podcast in a space owned by the other user, invisible to db_user.""" - space = SearchSpace(name="Stranger Space", user_id=db_other_user.id) - db_session.add(space) - await db_session.flush() - await create_default_roles_and_membership(db_session, space.id, db_other_user.id) - await db_session.flush() - return await make_podcast(search_space_id=space.id, title="Foreign") diff --git a/surfsense_backend/tests/integration/podcasts/test_brief_gate.py b/surfsense_backend/tests/integration/podcasts/test_brief_gate.py deleted file mode 100644 index 46d97172d..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_brief_gate.py +++ /dev/null @@ -1,80 +0,0 @@ -"""The brief review gate: edit the spec, then approve to start drafting. - -Covers what the user can do while ``awaiting_brief`` — edit the brief under -optimistic concurrency and approve it — and the HTTP status codes the service's -guards map to when an edit races or comes too late. -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def _create(client, search_space_id: int) -> dict: - resp = await client.post( - BASE, - json={ - "title": "Episode", - "search_space_id": search_space_id, - "source_content": "Source content.", - }, - ) - assert resp.status_code == 201 - return resp.json() - - -async def test_approve_brief_starts_drafting_and_enqueues_draft( - client, db_search_space, captured_tasks -): - podcast = await _create(client, db_search_space.id) - - resp = await client.post(f"{BASE}/{podcast['id']}/brief/approve") - - assert resp.status_code == 200 - assert resp.json()["status"] == "drafting" - assert captured_tasks.draft == [((podcast["id"], db_search_space.id), {})] - assert captured_tasks.render == [] - - -async def test_update_spec_bumps_version_and_persists(client, db_search_space): - podcast = await _create(client, db_search_space.id) - spec = podcast["spec"] - spec["focus"] = "A sharper angle" - - resp = await client.patch( - f"{BASE}/{podcast['id']}/spec", - json={"spec": spec, "expected_version": podcast["spec_version"]}, - ) - - assert resp.status_code == 200 - body = resp.json() - assert body["spec_version"] == podcast["spec_version"] + 1 - assert body["spec"]["focus"] == "A sharper angle" - assert body["status"] == "awaiting_brief" - - -async def test_update_spec_with_stale_version_conflicts(client, db_search_space): - podcast = await _create(client, db_search_space.id) - - resp = await client.patch( - f"{BASE}/{podcast['id']}/spec", - json={"spec": podcast["spec"], "expected_version": 999}, - ) - - assert resp.status_code == 409 - - -async def test_update_spec_after_approval_is_rejected(client, db_search_space): - podcast = await _create(client, db_search_space.id) - await client.post(f"{BASE}/{podcast['id']}/brief/approve") - - resp = await client.patch( - f"{BASE}/{podcast['id']}/spec", - json={"spec": podcast["spec"], "expected_version": podcast["spec_version"]}, - ) - - assert resp.status_code == 409 diff --git a/surfsense_backend/tests/integration/podcasts/test_cancel.py b/surfsense_backend/tests/integration/podcasts/test_cancel.py deleted file mode 100644 index 4fe4cfc55..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_cancel.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Cancelling a podcast: allowed while in flight, refused once an episode exists. - -Cancellation is the escape hatch for a podcast that has produced nothing yet. -Once a finished episode exists — including during a regeneration, whose audio -survives until a new render commits — cancel is refused (409): reverting the -regeneration is the way back, and no user action may destroy playable audio. -""" - -import pytest - -from app.podcasts.persistence import PodcastStatus - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_cancel_from_a_live_state_succeeds(client, db_search_space, make_podcast): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - - resp = await client.post(f"{BASE}/{podcast.id}/cancel") - - assert resp.status_code == 200 - assert resp.json()["status"] == "cancelled" - - -async def test_cancel_from_a_terminal_state_conflicts( - client, db_search_space, make_podcast -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.post(f"{BASE}/{podcast.id}/cancel") - - assert resp.status_code == 409 - - -async def test_cancel_of_a_regeneration_is_rejected( - client, db_search_space, make_podcast -): - # Cancelling here would destroy a playable episode; reverting the - # regeneration is the way back. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - resp = await client.post(f"{BASE}/{podcast.id}/cancel") - - assert resp.status_code == 409 - # The regeneration is still revertable afterwards. - follow_up = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - assert follow_up.status_code == 200 - assert follow_up.json()["status"] == "ready" diff --git a/surfsense_backend/tests/integration/podcasts/test_create.py b/surfsense_backend/tests/integration/podcasts/test_create.py deleted file mode 100644 index 19b5aeca2..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_create.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Creating a podcast proposes a brief and opens the review gate. - -Driven through the real POST endpoint (auth + DB on one transaction): the row is -created, a brief is proposed inline from defaults, and the podcast lands in -``awaiting_brief`` with a complete spec and nothing generated yet. -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_create_proposes_brief_and_opens_gate(client, db_search_space): - resp = await client.post( - BASE, - json={ - "title": "My Episode", - "search_space_id": db_search_space.id, - "source_content": "A long piece of source content about a topic.", - }, - ) - - assert resp.status_code == 201 - body = resp.json() - assert body["title"] == "My Episode" - assert body["status"] == "awaiting_brief" - assert body["spec_version"] == 1 - assert body["spec"] is not None - assert body["spec"]["language"] == "en" - assert len(body["spec"]["speakers"]) == 2 - assert body["transcript"] is None - assert body["has_audio"] is False - - -async def test_create_honors_requested_speaker_count(client, db_search_space): - resp = await client.post( - BASE, - json={ - "title": "Solo", - "search_space_id": db_search_space.id, - "source_content": "Content.", - "speaker_count": 3, - }, - ) - - assert resp.status_code == 201 - assert len(resp.json()["spec"]["speakers"]) == 3 diff --git a/surfsense_backend/tests/integration/podcasts/test_draft_task.py b/surfsense_backend/tests/integration/podcasts/test_draft_task.py deleted file mode 100644 index 7dadfc2f5..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_draft_task.py +++ /dev/null @@ -1,117 +0,0 @@ -"""The transcript-drafting task against a real database. - -Drafting is the expensive LLM step, so it runs under ``billable_call``. The -behavior that protects users' money: when billing succeeds, the drafted -transcript is stored and rendering starts immediately (DRAFTING -> RENDERING, -render task enqueued — the brief gate is the only approval); when billing denies -or settlement fails, the podcast ends FAILED with no transcript left behind. The -DB, service, and transcript persistence run for real; only the true externals -are faked — billing (the metering boundary) and the generation graph (the LLM). -""" - -from __future__ import annotations - -from contextlib import asynccontextmanager -from types import SimpleNamespace -from uuid import uuid4 - -import pytest - -from app.podcasts.persistence import PodcastStatus -from app.podcasts.service import read_transcript -from app.podcasts.tasks import draft -from app.services.billable_calls import ( - BillingSettlementError, - QuotaInsufficientError, -) - -from .conftest import build_transcript - -pytestmark = pytest.mark.integration - - -def _wire_billing(monkeypatch, *, billable_call, transcript=None) -> None: - """Replace the billing + LLM externals the draft body reaches for.""" - - async def _resolver(_session, _search_space_id, *, thread_id=None): - return uuid4(), "free", "openrouter/model" - - async def _ainvoke(_state, config=None): - return {"transcript": transcript} - - monkeypatch.setattr(draft, "_resolve_agent_billing_for_search_space", _resolver) - monkeypatch.setattr(draft, "billable_call", billable_call) - monkeypatch.setattr(draft, "transcript_graph", SimpleNamespace(ainvoke=_ainvoke)) - - -async def test_successful_draft_stores_transcript_and_starts_rendering( - monkeypatch, db_search_space, make_podcast, bind_task_session, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - @asynccontextmanager - async def _ok(**_kwargs): - yield SimpleNamespace() - - _wire_billing(monkeypatch, billable_call=_ok, transcript=build_transcript()) - - result = await draft._draft_transcript(podcast.id, db_search_space.id) - - assert result["status"] == "rendering" - assert podcast.status == PodcastStatus.RENDERING - assert read_transcript(podcast) is not None - assert captured_tasks.render == [((podcast.id,), {})] - - -async def test_quota_denial_fails_the_podcast_without_a_transcript( - monkeypatch, db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - @asynccontextmanager - async def _deny(**_kwargs): - raise QuotaInsufficientError( - usage_type="podcast_generation", - used_micros=5_000_000, - limit_micros=5_000_000, - remaining_micros=0, - ) - yield # pragma: no cover - unreachable, satisfies the CM protocol - - _wire_billing(monkeypatch, billable_call=_deny) - - result = await draft._draft_transcript(podcast.id, db_search_space.id) - - assert result["reason"] == "quota" - assert podcast.status == PodcastStatus.FAILED - assert read_transcript(podcast) is None - - -async def test_billing_settlement_failure_fails_the_podcast( - monkeypatch, db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - @asynccontextmanager - async def _settlement_fails(**_kwargs): - yield SimpleNamespace() - raise BillingSettlementError( - usage_type="podcast_generation", - user_id=uuid4(), - cause=RuntimeError("finalize failed"), - ) - - _wire_billing( - monkeypatch, billable_call=_settlement_fails, transcript=build_transcript() - ) - - result = await draft._draft_transcript(podcast.id, db_search_space.id) - - assert result["reason"] == "billing" - assert podcast.status == PodcastStatus.FAILED diff --git a/surfsense_backend/tests/integration/podcasts/test_public_stream.py b/surfsense_backend/tests/integration/podcasts/test_public_stream.py deleted file mode 100644 index d2ba1d1b9..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_public_stream.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Public (unauthenticated) podcast streaming from a chat snapshot. - -A shared chat snapshot carries each podcast's stored-audio key; the public route -streams those bytes from the object store via ``share_token`` with no auth. A -podcast that isn't in the snapshot is a 404. -""" - -import pytest - -from app.db import NewChatThread, PublicChatSnapshot, User - -pytestmark = pytest.mark.integration - - -async def _snapshot(db_session, *, search_space_id, user: User, token: str, podcasts): - thread = NewChatThread( - title="Shared", search_space_id=search_space_id, created_by_id=user.id - ) - db_session.add(thread) - await db_session.flush() - snapshot = PublicChatSnapshot( - thread_id=thread.id, - share_token=token, - content_hash=f"hash-{token}", - message_ids=[], - snapshot_data={"podcasts": podcasts}, - ) - db_session.add(snapshot) - await db_session.flush() - - -async def test_public_stream_serves_audio_via_storage_key( - client, db_session, db_search_space, db_user, fake_storage -): - await _snapshot( - db_session, - search_space_id=db_search_space.id, - user=db_user, - token="tok-audio", - podcasts=[{"original_id": 555, "storage_key": "podcasts/x.mp3"}], - ) - fake_storage.objects["podcasts/x.mp3"] = b"public-audio" - - resp = await client.get("/api/v1/public/tok-audio/podcasts/555/stream") - - assert resp.status_code == 200 - assert resp.headers["content-type"] == "audio/mpeg" - assert resp.content == b"public-audio" - - -async def test_public_stream_404_when_podcast_absent_from_snapshot( - client, db_session, db_search_space, db_user -): - await _snapshot( - db_session, - search_space_id=db_search_space.id, - user=db_user, - token="tok-empty", - podcasts=[], - ) - - resp = await client.get("/api/v1/public/tok-empty/podcasts/999/stream") - - assert resp.status_code == 404 diff --git a/surfsense_backend/tests/integration/podcasts/test_regeneration.py b/surfsense_backend/tests/integration/podcasts/test_regeneration.py deleted file mode 100644 index fd31df4ca..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_regeneration.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Regeneration: the listen-then-redo loop after the brief gate. - -A user who dislikes the finished audio sends the episode back to the brief -gate: the saved brief reopens for tweaks (voices, length, focus) and drafting -only restarts on a fresh approval. The whole redo can also be reverted at any -point before the new render commits, falling back to the still-stored episode. -These pin the READY -> AWAITING_BRIEF -> DRAFTING round trip, the revert -fallback, and the 409s for acting from states that have nothing to redo or -revert. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import Podcast, PodcastStatus -from app.podcasts.service import PodcastService - -from .conftest import build_transcript - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_regenerate_from_ready_reopens_the_brief_gate( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "awaiting_brief" - # The prior brief is kept as the starting point for the new take. - assert body["spec"] is not None - # Nothing drafts until the user approves the reopened brief. - assert captured_tasks.draft == [] - assert captured_tasks.render == [] - - -async def test_approving_the_reopened_brief_starts_a_fresh_draft( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - resp = await client.post(f"{BASE}/{podcast.id}/brief/approve") - - assert resp.status_code == 200 - assert resp.json()["status"] == "drafting" - assert captured_tasks.draft == [((podcast.id, db_search_space.id), {})] - - -async def test_regenerate_from_brief_gate_is_rejected( - client, db_search_space, make_podcast, captured_tasks -): - # Nothing has been drafted yet, so there is nothing to regenerate. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 409 - assert captured_tasks.draft == [] - - -async def test_regenerate_from_cancelled_is_rejected( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - await client.post(f"{BASE}/{podcast.id}/cancel") - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 409 - assert captured_tasks.draft == [] - - -async def test_reverting_a_regeneration_restores_the_ready_episode( - client, db_search_space, make_podcast, captured_tasks -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "ready" - # The episode the user could already play is untouched. - assert body["has_audio"] is True - assert captured_tasks.draft == [] - assert captured_tasks.render == [] - - -async def test_reverting_mid_draft_keeps_the_episode( - client, db_search_space, make_podcast -): - # Changing one's mind is allowed even after the reopened brief was - # approved: the episode survives until a new render replaces it. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - await client.post(f"{BASE}/{podcast.id}/brief/approve") - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 200 - assert resp.json()["status"] == "ready" - - -async def test_reverting_mid_render_keeps_the_episode( - client, db_session, db_search_space, make_podcast -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - service = PodcastService(db_session) - await service.regenerate(podcast) - await service.begin_drafting(podcast) - await service.attach_transcript(podcast, build_transcript()) - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 200 - assert resp.json()["status"] == "ready" - - -async def test_reverted_episode_can_be_regenerated_again( - client, db_search_space, make_podcast -): - # Reverting must not strand the episode: the user can change their mind - # again immediately. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 200 - assert resp.json()["status"] == "awaiting_brief" - - -async def test_revert_on_a_fresh_brief_gate_is_rejected( - client, db_search_space, make_podcast -): - # A first-time brief has no regeneration to revert. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF - ) - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 409 - assert resp.json()["detail"] - - -async def test_revert_when_nothing_was_regenerated_is_rejected( - client, db_search_space, make_podcast -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - resp = await client.post(f"{BASE}/{podcast.id}/regenerate/revert") - - assert resp.status_code == 409 - - -async def test_regenerate_without_a_brief_is_rejected( - client, db_session, db_search_space, captured_tasks -): - # Legacy episodes finished before briefs existed; reopening a gate with - # nothing to review would strand them there. - podcast = Podcast( - title="Legacy Episode", - search_space_id=db_search_space.id, - status=PodcastStatus.READY, - spec_version=1, - file_location="/var/old/podcast.mp3", - ) - db_session.add(podcast) - await db_session.flush() - - resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate") - - assert resp.status_code == 422 - assert captured_tasks.draft == [] diff --git a/surfsense_backend/tests/integration/podcasts/test_render_task.py b/surfsense_backend/tests/integration/podcasts/test_render_task.py deleted file mode 100644 index 5a97a00c7..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_render_task.py +++ /dev/null @@ -1,100 +0,0 @@ -"""The audio-rendering task against a real database. - -From RENDERING, the task synthesises and merges the approved transcript, stores -the bytes, and marks the podcast READY with the storage location recorded. The -DB, service, renderer orchestration, and storage wrapper run for real; the true -externals are faked — the TTS provider, the FFmpeg merge, and the object store. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import PodcastStatus -from app.podcasts.service import PodcastService -from app.podcasts.tasks import render - -from .conftest import build_transcript - -pytestmark = pytest.mark.integration - - -async def test_render_marks_ready_and_stores_audio( - db_search_space, make_podcast, bind_task_session, fake_tts, fake_merge, fake_storage -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.RENDERING - ) - - result = await render._render_audio(podcast.id) - - assert result["status"] == "ready" - assert podcast.status == PodcastStatus.READY - assert podcast.storage_backend == "memory" - assert podcast.storage_key - assert fake_storage.objects[podcast.storage_key] == b"merged-audio" - - -async def test_rerender_replaces_audio_and_purges_the_old_object( - db_session, - db_search_space, - make_podcast, - bind_task_session, - fake_tts, - fake_merge, - fake_storage, -): - # A regenerated episode keeps exactly one stored object: the new render - # must not leak the superseded audio in the object store. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - old_key = podcast.storage_key - fake_storage.objects[old_key] = b"old-audio" - - service = PodcastService(db_session) - await service.regenerate(podcast) - await service.begin_drafting(podcast) - await service.attach_transcript(podcast, build_transcript()) - - result = await render._render_audio(podcast.id) - - assert result["status"] == "ready" - assert podcast.status == PodcastStatus.READY - assert podcast.storage_key != old_key - assert fake_storage.objects[podcast.storage_key] == b"merged-audio" - assert old_key in fake_storage.deleted - - -async def test_render_losing_to_a_user_revert_keeps_the_episode_and_leaks_nothing( - db_session, - db_search_space, - make_podcast, - bind_task_session, - fake_tts, - fake_merge, - fake_storage, -): - # The user reverts the regeneration while the render is in flight: the - # stale render must neither resurrect the redo nor leak the object it - # already stored. - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - old_key = podcast.storage_key - fake_storage.objects[old_key] = b"old-audio" - - service = PodcastService(db_session) - await service.regenerate(podcast) - await service.begin_drafting(podcast) - await service.attach_transcript(podcast, build_transcript()) - await service.revert_regeneration(podcast) - - result = await render._render_audio(podcast.id) - - assert result["status"] == "superseded" - assert podcast.status == PodcastStatus.READY - assert podcast.storage_key == old_key - assert old_key not in fake_storage.deleted - stale_keys = [key for key in fake_storage.objects if key != old_key] - assert all(key in fake_storage.deleted for key in stale_keys) diff --git a/surfsense_backend/tests/integration/podcasts/test_scoping.py b/surfsense_backend/tests/integration/podcasts/test_scoping.py deleted file mode 100644 index 304af6b6e..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_scoping.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Podcasts are scoped to search-space membership. - -A user can only create or read podcasts in spaces they belong to, and an -unscoped listing returns only the caller's own podcasts — never another -member's. -""" - -import pytest - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_reading_a_podcast_in_a_nonmember_space_is_forbidden( - client, db_search_space, make_podcast, act_as, db_other_user -): - podcast = await make_podcast(search_space_id=db_search_space.id) - act_as(db_other_user) - - resp = await client.get(f"{BASE}/{podcast.id}") - - assert resp.status_code == 403 - - -async def test_creating_in_a_nonmember_space_is_forbidden( - client, db_search_space, act_as, db_other_user -): - act_as(db_other_user) - - resp = await client.post( - BASE, - json={ - "title": "X", - "search_space_id": db_search_space.id, - "source_content": "content", - }, - ) - - assert resp.status_code == 403 - - -async def test_listing_returns_only_the_callers_podcasts( - client, db_search_space, make_podcast, foreign_podcast -): - mine = await make_podcast(search_space_id=db_search_space.id, title="Mine") - - resp = await client.get(BASE) - - assert resp.status_code == 200 - ids = {p["id"] for p in resp.json()} - assert mine.id in ids - assert foreign_podcast.id not in ids diff --git a/surfsense_backend/tests/integration/podcasts/test_streaming.py b/surfsense_backend/tests/integration/podcasts/test_streaming.py deleted file mode 100644 index 82456bac9..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_streaming.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Streaming a podcast's rendered audio over HTTP. - -A ready podcast streams its bytes from the storage backend; a podcast with no -stored audio returns 404. Storage is an in-memory backend (the object store is a -system boundary). -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import PodcastStatus - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_stream_serves_stored_audio( - client, db_search_space, make_podcast, fake_storage -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - fake_storage.objects["podcasts/audio.mp3"] = b"the-audio" - - resp = await client.get(f"{BASE}/{podcast.id}/stream") - - assert resp.status_code == 200 - assert resp.headers["content-type"] == "audio/mpeg" - assert resp.content == b"the-audio" - - -async def test_stream_404_when_no_audio(client, db_search_space, make_podcast): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - resp = await client.get(f"{BASE}/{podcast.id}/stream") - - assert resp.status_code == 404 diff --git a/surfsense_backend/tests/integration/podcasts/test_task_failure.py b/surfsense_backend/tests/integration/podcasts/test_task_failure.py deleted file mode 100644 index 43212f58f..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_task_failure.py +++ /dev/null @@ -1,45 +0,0 @@ -"""The task failure safety net (``mark_failed``) against a real database. - -When a task body raises, ``mark_failed`` records the reason on the row. Its -contract has two halves worth securing: a still-running podcast moves to FAILED -with the reason, while one that already reached a terminal state is left exactly -as it was rather than forced. A missing row is a no-op, never a crash. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.persistence import PodcastStatus -from app.podcasts.tasks import runtime - -pytestmark = pytest.mark.integration - - -async def test_marking_failed_records_the_reason_on_a_running_podcast( - db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING - ) - - await runtime.mark_failed(podcast.id, "tts provider unavailable") - - assert podcast.status == PodcastStatus.FAILED - assert podcast.error == "tts provider unavailable" - - -async def test_marking_failed_leaves_an_already_terminal_podcast_untouched( - db_search_space, make_podcast, bind_task_session -): - podcast = await make_podcast( - search_space_id=db_search_space.id, status=PodcastStatus.READY - ) - - await runtime.mark_failed(podcast.id, "too late") - - assert podcast.status == PodcastStatus.READY - - -async def test_marking_a_missing_podcast_failed_is_a_no_op(bind_task_session): - await runtime.mark_failed(987654321, "gone") # must not raise diff --git a/surfsense_backend/tests/integration/podcasts/test_voice_preview.py b/surfsense_backend/tests/integration/podcasts/test_voice_preview.py deleted file mode 100644 index 113172bee..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_voice_preview.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Audible voice previews for the brief gate's voice picker. - -A user choosing voices should hear them, not guess from names. The endpoint -synthesises a short sample for a catalog voice and caches it on disk so each -voice is paid for at most once per process lifetime. Unknown voices and voices -of an inactive provider are 404; no configured TTS is 503. -""" - -from __future__ import annotations - -import pytest - -from app.config import config as app_config - -from .conftest import FakeTextToSpeech - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -@pytest.fixture -def preview_tts(monkeypatch, tmp_path) -> FakeTextToSpeech: - """Route preview synthesis to the fake provider and an isolated cache.""" - provider = FakeTextToSpeech() - monkeypatch.setattr("app.podcasts.api.routes.get_text_to_speech", lambda: provider) - monkeypatch.setattr("app.podcasts.voices.preview.PREVIEW_CACHE_ROOT", tmp_path) - return provider - - -async def test_preview_returns_playable_audio_for_a_catalog_voice(client, preview_tts): - resp = await client.get(f"{BASE}/voices/openai:alloy/preview") - - assert resp.status_code == 200 - assert resp.headers["content-type"] == "audio/mpeg" - assert resp.content == b"segment-audio" - - -async def test_preview_is_synthesised_once_then_served_from_cache(client, preview_tts): - first = await client.get(f"{BASE}/voices/openai:alloy/preview") - second = await client.get(f"{BASE}/voices/openai:alloy/preview") - - assert first.status_code == second.status_code == 200 - assert second.content == first.content - assert len(preview_tts.requests) == 1 - - -async def test_preview_unknown_voice_is_404(client, preview_tts): - resp = await client.get(f"{BASE}/voices/openai:nope/preview") - - assert resp.status_code == 404 - assert preview_tts.requests == [] - - -async def test_preview_voice_of_inactive_provider_is_404(client, preview_tts): - # The active provider is OpenAI (pinned in conftest); a Kokoro voice exists - # in the catalog but cannot be heard through the configured provider. - resp = await client.get(f"{BASE}/voices/kokoro:af_heart/preview") - - assert resp.status_code == 404 - assert preview_tts.requests == [] - - -async def test_preview_without_tts_provider_is_503(client, preview_tts, monkeypatch): - monkeypatch.setattr(app_config, "TTS_SERVICE", None) - - resp = await client.get(f"{BASE}/voices/openai:alloy/preview") - - assert resp.status_code == 503 diff --git a/surfsense_backend/tests/integration/podcasts/test_voices.py b/surfsense_backend/tests/integration/podcasts/test_voices.py deleted file mode 100644 index 688ddad56..000000000 --- a/surfsense_backend/tests/integration/podcasts/test_voices.py +++ /dev/null @@ -1,31 +0,0 @@ -"""GET /podcasts/voices: the active provider's catalog, or 503 if unconfigured. - -The brief UI needs the voices the configured TTS provider offers; with no -provider configured there is nothing to choose from, which is a 503 rather than -an empty list. -""" - -import pytest - -from app.config import config as app_config - -pytestmark = pytest.mark.integration - -BASE = "/api/v1/podcasts" - - -async def test_voices_returns_the_active_providers_catalog(client): - resp = await client.get(f"{BASE}/voices") - - assert resp.status_code == 200 - voices = resp.json() - assert voices # openai/tts-1 offers voices - assert {"voice_id", "display_name", "language", "gender"} <= voices[0].keys() - - -async def test_voices_503_when_no_tts_configured(client, monkeypatch): - monkeypatch.setattr(app_config, "TTS_SERVICE", "") - - resp = await client.get(f"{BASE}/voices") - - assert resp.status_code == 503 diff --git a/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py index a74591169..b87d1be42 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_dropbox_parallel.py @@ -272,26 +272,22 @@ def full_scan_mocks(mock_dropbox_client, monkeypatch): download_and_index_mock = AsyncMock(return_value=(0, 0)) monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) - from app.services.etl_credit_service import EtlCreditService as _RealECS + from app.services.page_limit_service import PageLimitService as _RealPLS - # get_available_micros -> None means "unlimited" (billing disabled), so no - # batch is gated and charge_credits is a no-op — matching the prior - # 999_999 page-limit intent for these parallel-processing tests. - mock_credit_instance = MagicMock() - mock_credit_instance.get_available_micros = AsyncMock(return_value=None) - mock_credit_instance.charge_credits = AsyncMock(return_value=None) + mock_page_limit_instance = MagicMock() + mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999)) + mock_page_limit_instance.update_page_usage = AsyncMock() - class _MockEtlCreditService: + class _MockPageLimitService: estimate_pages_from_metadata = staticmethod( - _RealECS.estimate_pages_from_metadata + _RealPLS.estimate_pages_from_metadata ) - pages_to_micros = staticmethod(_RealECS.pages_to_micros) def __init__(self, session): - self.get_available_micros = mock_credit_instance.get_available_micros - self.charge_credits = mock_credit_instance.charge_credits + self.get_page_usage = mock_page_limit_instance.get_page_usage + self.update_page_usage = mock_page_limit_instance.update_page_usage - monkeypatch.setattr(_mod, "EtlCreditService", _MockEtlCreditService) + monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService) return { "dropbox_client": mock_dropbox_client, @@ -397,23 +393,22 @@ def selected_files_mocks(mock_dropbox_client, monkeypatch): download_and_index_mock = AsyncMock(return_value=(0, 0)) monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) - from app.services.etl_credit_service import EtlCreditService as _RealECS + from app.services.page_limit_service import PageLimitService as _RealPLS - mock_credit_instance = MagicMock() - mock_credit_instance.get_available_micros = AsyncMock(return_value=None) - mock_credit_instance.charge_credits = AsyncMock(return_value=None) + mock_page_limit_instance = MagicMock() + mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999)) + mock_page_limit_instance.update_page_usage = AsyncMock() - class _MockEtlCreditService: + class _MockPageLimitService: estimate_pages_from_metadata = staticmethod( - _RealECS.estimate_pages_from_metadata + _RealPLS.estimate_pages_from_metadata ) - pages_to_micros = staticmethod(_RealECS.pages_to_micros) def __init__(self, session): - self.get_available_micros = mock_credit_instance.get_available_micros - self.charge_credits = mock_credit_instance.charge_credits + self.get_page_usage = mock_page_limit_instance.get_page_usage + self.update_page_usage = mock_page_limit_instance.update_page_usage - monkeypatch.setattr(_mod, "EtlCreditService", _MockEtlCreditService) + monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService) return { "dropbox_client": mock_dropbox_client, diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py index 4f61976a6..9a13e4525 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py @@ -242,28 +242,20 @@ def _folder_dict(file_id: str, name: str) -> dict: } -def _make_page_limit_session(balance_micros=999_999_000, reserved_micros=0): - """Build a mock DB session that real EtlCreditService can operate against. - - ETL credit billing is disabled by default in tests, so get_available_micros - short-circuits to None ("unlimited") and these fields are unused; they're - provided for parity if a test opts into billing. - """ +def _make_page_limit_session(pages_used=0, pages_limit=999_999): + """Build a mock DB session that real PageLimitService can operate against.""" class _FakeUser: - def __init__(self, balance, reserved): - self.credit_micros_balance = balance - self.credit_micros_reserved = reserved + def __init__(self, pu, pl): + self.pages_used = pu + self.pages_limit = pl - fake_user = _FakeUser(balance_micros, reserved_micros) + fake_user = _FakeUser(pages_used, pages_limit) session = AsyncMock() def _make_result(*_a, **_kw): r = MagicMock() - r.first.return_value = ( - fake_user.credit_micros_balance, - fake_user.credit_micros_reserved, - ) + r.first.return_value = (fake_user.pages_used, fake_user.pages_limit) r.unique.return_value.scalar_one_or_none.return_value = fake_user return r diff --git a/surfsense_backend/tests/unit/connector_indexers/test_etl_credits.py b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py similarity index 73% rename from surfsense_backend/tests/unit/connector_indexers/test_etl_credits.py rename to surfsense_backend/tests/unit/connector_indexers/test_page_limits.py index aca811ee9..66722ffd7 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_etl_credits.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py @@ -1,22 +1,17 @@ -"""Tests for ETL credit enforcement in connector indexers. +"""Tests for page limit enforcement in connector indexers. Covers: - A) EtlCreditService.estimate_pages_from_metadata — pure function (no mocks) - B) Credit-wallet gating in the connector indexers, tested through the real - EtlCreditService with a mock DB session (system boundary). ETL credit - billing is force-enabled per-test so the gating path is exercised. + A) PageLimitService.estimate_pages_from_metadata — pure function (no mocks) + B) Page-limit quota gating in _index_selected_files tested through the + real PageLimitService with a mock DB session (system boundary). Google Drive is the primary, with OneDrive/Dropbox smoke tests. - -Page estimates are converted to micro-USD at ``config.MICROS_PER_PAGE`` per -page and debited from ``user.credit_micros_balance``. """ from unittest.mock import AsyncMock, MagicMock import pytest -from app.config import config -from app.services.etl_credit_service import EtlCreditService +from app.services.page_limit_service import PageLimitService pytestmark = pytest.mark.unit @@ -25,23 +20,8 @@ _CONNECTOR_ID = 42 _SEARCH_SPACE_ID = 1 -def _micros(pages: int) -> int: - """Convert a page count to micro-USD using the configured rate.""" - return pages * config.MICROS_PER_PAGE - - -@pytest.fixture(autouse=True) -def _enable_etl_billing(monkeypatch): - """Force ETL credit billing on so the gating/charging path runs. - - It defaults to off (self-hosted/OSS), which would short-circuit - get_available_micros to None and bypass every check in this module. - """ - monkeypatch.setattr(config, "ETL_CREDIT_BILLING_ENABLED", True) - - # =================================================================== -# A) EtlCreditService.estimate_pages_from_metadata — pure function +# A) PageLimitService.estimate_pages_from_metadata — pure function # No mocks: it's a staticmethod with no I/O. # =================================================================== @@ -50,91 +30,88 @@ class TestEstimatePagesFromMetadata: """Vertical slices for the page estimation staticmethod.""" def test_pdf_100kb_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 100 * 1024) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 100 * 1024) == 1 def test_pdf_500kb_returns_5(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 500 * 1024) == 5 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 500 * 1024) == 5 def test_pdf_1mb(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 1024 * 1024) == 10 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 1024 * 1024) == 10 def test_docx_50kb_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".docx", 50 * 1024) == 1 + assert PageLimitService.estimate_pages_from_metadata(".docx", 50 * 1024) == 1 def test_docx_200kb(self): - assert EtlCreditService.estimate_pages_from_metadata(".docx", 200 * 1024) == 4 + assert PageLimitService.estimate_pages_from_metadata(".docx", 200 * 1024) == 4 def test_pptx_uses_200kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".pptx", 600 * 1024) == 3 + assert PageLimitService.estimate_pages_from_metadata(".pptx", 600 * 1024) == 3 def test_xlsx_uses_100kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".xlsx", 300 * 1024) == 3 + assert PageLimitService.estimate_pages_from_metadata(".xlsx", 300 * 1024) == 3 def test_txt_uses_3000_bytes_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".txt", 9000) == 3 + assert PageLimitService.estimate_pages_from_metadata(".txt", 9000) == 3 def test_image_always_returns_1(self): for ext in (".jpg", ".png", ".gif", ".webp"): - assert EtlCreditService.estimate_pages_from_metadata(ext, 5_000_000) == 1 + assert PageLimitService.estimate_pages_from_metadata(ext, 5_000_000) == 1 def test_audio_uses_1mb_per_page(self): assert ( - EtlCreditService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3 + PageLimitService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3 ) def test_video_uses_5mb_per_page(self): assert ( - EtlCreditService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3 + PageLimitService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3 ) def test_unknown_ext_uses_80kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2 + assert PageLimitService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2 def test_zero_size_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 0) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 0) == 1 def test_negative_size_returns_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", -500) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", -500) == 1 def test_minimum_is_always_1(self): - assert EtlCreditService.estimate_pages_from_metadata(".pdf", 50) == 1 + assert PageLimitService.estimate_pages_from_metadata(".pdf", 50) == 1 def test_epub_uses_50kb_per_page(self): - assert EtlCreditService.estimate_pages_from_metadata(".epub", 250 * 1024) == 5 + assert PageLimitService.estimate_pages_from_metadata(".epub", 250 * 1024) == 5 # =================================================================== -# B) Credit enforcement in connector indexers -# System boundary mocked: DB session (for EtlCreditService) +# B) Page-limit enforcement in connector indexers +# System boundary mocked: DB session (for PageLimitService) # System boundary mocked: external API clients, download/ETL -# NOT mocked: EtlCreditService itself (our own code) +# NOT mocked: PageLimitService itself (our own code) # =================================================================== class _FakeUser: """Stands in for the User ORM model at the DB boundary.""" - def __init__(self, balance_micros: int = 0, reserved_micros: int = 0): - self.credit_micros_balance = balance_micros - self.credit_micros_reserved = reserved_micros + def __init__(self, pages_used: int = 0, pages_limit: int = 100): + self.pages_used = pages_used + self.pages_limit = pages_limit -def _make_credit_session(balance_micros: int = _micros(100), reserved_micros: int = 0): - """Build a mock DB session that the real EtlCreditService can operate against. +def _make_page_limit_session(pages_used: int = 0, pages_limit: int = 100): + """Build a mock DB session that real PageLimitService can operate against. Every ``session.execute()`` returns a result compatible with both - ``get_available_micros`` (.first() → ``(balance, reserved)``) and - ``charge_credits`` (.unique().scalar_one_or_none() → User-like). + ``get_page_usage`` (.first() → tuple) and ``update_page_usage`` + (.unique().scalar_one_or_none() → User-like). """ - fake_user = _FakeUser(balance_micros, reserved_micros) + fake_user = _FakeUser(pages_used, pages_limit) session = AsyncMock() def _make_result(*_args, **_kwargs): result = MagicMock() - result.first.return_value = ( - fake_user.credit_micros_balance, - fake_user.credit_micros_reserved, - ) + result.first.return_value = (fake_user.pages_used, fake_user.pages_limit) result.unique.return_value.scalar_one_or_none.return_value = fake_user return result @@ -161,7 +138,7 @@ def gdrive_selected_mocks(monkeypatch): """Mocks for Google Drive _index_selected_files — only system boundaries.""" import app.tasks.connector_indexers.google_drive_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) get_file_results: dict[str, tuple[dict | None, str | None]] = {} @@ -206,11 +183,12 @@ async def _run_gdrive_selected(mocks, file_ids): ) -async def test_gdrive_files_within_credit_are_downloaded(gdrive_selected_mocks): - """Files whose cumulative estimated cost fits within available credit +async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks): + """Files whose cumulative estimated pages fit within remaining quota are sent to _download_and_index.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2", "f3"): m["get_file_results"][fid] = ( @@ -229,10 +207,11 @@ async def test_gdrive_files_within_credit_are_downloaded(gdrive_selected_mocks): assert len(call_files) == 3 -async def test_gdrive_files_exceeding_credit_rejected(gdrive_selected_mocks): - """Files whose cost would exceed available credit are rejected.""" +async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks): + """Files whose pages would exceed remaining quota are rejected.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(2) + m["fake_user"].pages_used = 98 + m["fake_user"].pages_limit = 100 m["get_file_results"]["big"] = ( _make_gdrive_file("big", "huge.pdf", size=500 * 1024), @@ -245,13 +224,14 @@ async def test_gdrive_files_exceeding_credit_rejected(gdrive_selected_mocks): assert indexed == 0 assert len(errors) == 1 - assert "insufficient credits" in errors[0].lower() + assert "page limit" in errors[0].lower() -async def test_gdrive_credit_mix_partial_indexing(gdrive_selected_mocks): - """3rd file pushes over available credit → only first two indexed.""" +async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks): + """3rd file pushes over quota → only first two indexed.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(2) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 2 for fid in ("f1", "f2", "f3"): m["get_file_results"][fid] = ( @@ -270,10 +250,11 @@ async def test_gdrive_credit_mix_partial_indexing(gdrive_selected_mocks): assert {f["id"] for f in call_files} == {"f1", "f2"} -async def test_gdrive_proportional_credit_deduction(gdrive_selected_mocks): - """Credit deducted is proportional to successfully indexed files.""" +async def test_gdrive_proportional_page_deduction(gdrive_selected_mocks): + """Pages deducted are proportional to successfully indexed files.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2", "f3", "f4"): m["get_file_results"][fid] = ( @@ -287,14 +268,14 @@ async def test_gdrive_proportional_credit_deduction(gdrive_selected_mocks): [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz"), ("f4", "f4.xyz")], ) - # 4 estimated pages, 2 of 4 indexed → deduct 2 pages. - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(2) + assert m["fake_user"].pages_used == 2 async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks): - """If batch_indexed == 0, the user's balance stays unchanged.""" + """If batch_indexed == 0, user's pages_used stays unchanged.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(95) + m["fake_user"].pages_used = 5 + m["fake_user"].pages_limit = 100 m["get_file_results"]["f1"] = ( _make_gdrive_file("f1", "f1.xyz", size=80 * 1024), @@ -304,13 +285,14 @@ async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks): await _run_gdrive_selected(m, [("f1", "f1.xyz")]) - assert m["fake_user"].credit_micros_balance == _micros(95) + assert m["fake_user"].pages_used == 5 -async def test_gdrive_zero_credit_rejects_all(gdrive_selected_mocks): - """When the balance is exhausted, every file is rejected.""" +async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks): + """When pages_used == pages_limit, every file is rejected.""" m = gdrive_selected_mocks - m["fake_user"].credit_micros_balance = 0 + m["fake_user"].pages_used = 100 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2"): m["get_file_results"][fid] = ( @@ -335,7 +317,7 @@ async def test_gdrive_zero_credit_rejects_all(gdrive_selected_mocks): def gdrive_full_scan_mocks(monkeypatch): import app.tasks.connector_indexers.google_drive_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) mock_task_logger = MagicMock() mock_task_logger.log_task_progress = AsyncMock() @@ -382,9 +364,10 @@ async def _run_gdrive_full_scan(mocks, max_files=500): ) -async def test_gdrive_full_scan_skips_over_credit(gdrive_full_scan_mocks, monkeypatch): +async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeypatch): m = gdrive_full_scan_mocks - m["fake_user"].credit_micros_balance = _micros(2) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 2 page_files = [ _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(5) @@ -408,7 +391,8 @@ async def test_gdrive_full_scan_deducts_after_indexing( gdrive_full_scan_mocks, monkeypatch ): m = gdrive_full_scan_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 page_files = [ _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(3) @@ -424,7 +408,7 @@ async def test_gdrive_full_scan_deducts_after_indexing( await _run_gdrive_full_scan(m) - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(3) + assert m["fake_user"].pages_used == 3 # --------------------------------------------------------------------------- @@ -432,10 +416,10 @@ async def test_gdrive_full_scan_deducts_after_indexing( # --------------------------------------------------------------------------- -async def test_gdrive_delta_sync_skips_over_credit(monkeypatch): +async def test_gdrive_delta_sync_skips_over_quota(monkeypatch): import app.tasks.connector_indexers.google_drive_indexer as _mod - session, _ = _make_credit_session(_micros(2)) + session, _ = _make_page_limit_session(0, 2) changes = [ { @@ -487,7 +471,7 @@ async def test_gdrive_delta_sync_skips_over_credit(monkeypatch): # =================================================================== -# C) OneDrive smoke tests — verify credit wiring +# C) OneDrive smoke tests — verify page limit wiring # =================================================================== @@ -505,7 +489,7 @@ def _make_onedrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict: def onedrive_selected_mocks(monkeypatch): import app.tasks.connector_indexers.onedrive_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) get_file_results: dict[str, tuple[dict | None, str | None]] = {} @@ -547,10 +531,11 @@ async def _run_onedrive_selected(mocks, file_ids): ) -async def test_onedrive_over_credit_rejected(onedrive_selected_mocks): - """OneDrive: files exceeding available credit produce errors, not downloads.""" +async def test_onedrive_over_quota_rejected(onedrive_selected_mocks): + """OneDrive: files exceeding quota produce errors, not downloads.""" m = onedrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(1) + m["fake_user"].pages_used = 99 + m["fake_user"].pages_limit = 100 m["get_file_results"]["big"] = ( _make_onedrive_file("big", "huge.pdf", size=500 * 1024), @@ -563,13 +548,14 @@ async def test_onedrive_over_credit_rejected(onedrive_selected_mocks): assert indexed == 0 assert len(errors) == 1 - assert "insufficient credits" in errors[0].lower() + assert "page limit" in errors[0].lower() async def test_onedrive_deducts_after_success(onedrive_selected_mocks): - """OneDrive: balance decreases after successful indexing.""" + """OneDrive: pages_used increases after successful indexing.""" m = onedrive_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for fid in ("f1", "f2"): m["get_file_results"][fid] = ( @@ -580,11 +566,11 @@ async def test_onedrive_deducts_after_success(onedrive_selected_mocks): await _run_onedrive_selected(m, [("f1", "f1.xyz"), ("f2", "f2.xyz")]) - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(2) + assert m["fake_user"].pages_used == 2 # =================================================================== -# D) Dropbox smoke tests — verify credit wiring +# D) Dropbox smoke tests — verify page limit wiring # =================================================================== @@ -604,7 +590,7 @@ def _make_dropbox_file(file_path: str, name: str, size: int = 80 * 1024) -> dict def dropbox_selected_mocks(monkeypatch): import app.tasks.connector_indexers.dropbox_indexer as _mod - session, fake_user = _make_credit_session(_micros(100)) + session, fake_user = _make_page_limit_session(0, 100) get_file_results: dict[str, tuple[dict | None, str | None]] = {} @@ -646,10 +632,11 @@ async def _run_dropbox_selected(mocks, file_paths): ) -async def test_dropbox_over_credit_rejected(dropbox_selected_mocks): - """Dropbox: files exceeding available credit produce errors, not downloads.""" +async def test_dropbox_over_quota_rejected(dropbox_selected_mocks): + """Dropbox: files exceeding quota produce errors, not downloads.""" m = dropbox_selected_mocks - m["fake_user"].credit_micros_balance = _micros(1) + m["fake_user"].pages_used = 99 + m["fake_user"].pages_limit = 100 m["get_file_results"]["/huge.pdf"] = ( _make_dropbox_file("/huge.pdf", "huge.pdf", size=500 * 1024), @@ -662,13 +649,14 @@ async def test_dropbox_over_credit_rejected(dropbox_selected_mocks): assert indexed == 0 assert len(errors) == 1 - assert "insufficient credits" in errors[0].lower() + assert "page limit" in errors[0].lower() async def test_dropbox_deducts_after_success(dropbox_selected_mocks): - """Dropbox: balance decreases after successful indexing.""" + """Dropbox: pages_used increases after successful indexing.""" m = dropbox_selected_mocks - m["fake_user"].credit_micros_balance = _micros(100) + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 for name in ("f1.xyz", "f2.xyz"): path = f"/{name}" @@ -680,4 +668,4 @@ async def test_dropbox_deducts_after_success(dropbox_selected_mocks): await _run_dropbox_selected(m, [("/f1.xyz", "f1.xyz"), ("/f2.xyz", "f2.xyz")]) - assert m["fake_user"].credit_micros_balance == _micros(100) - _micros(2) + assert m["fake_user"].pages_used == 2 diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_insufficient_credits.py b/surfsense_backend/tests/unit/notifications/service/messages/test_insufficient_credits.py deleted file mode 100644 index c5366cce2..000000000 --- a/surfsense_backend/tests/unit/notifications/service/messages/test_insufficient_credits.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Unit tests for insufficient-credits presentation logic.""" - -from __future__ import annotations - -import pytest - -from app.notifications.service.messages import insufficient_credits as msg - -pytestmark = pytest.mark.unit - - -def test_operation_id_encodes_search_space(): - """The operation id embeds the search space id.""" - assert msg.operation_id("doc.pdf", 9).startswith("insufficient_credits_9_") - - -def test_summary_title_and_message(): - """The summary states the document and the required/available credit.""" - title, message = msg.summary( - "short.pdf", balance_micros=250_000, required_micros=1_000_000 - ) - assert title == "Insufficient credits: short.pdf" - assert message == ( - "This document costs about $1.00 to process but you have " - "$0.25 of credit left. Add more credits to continue." - ) - - -def test_summary_clamps_negative_balance_to_zero(): - """A negative balance is clamped to $0.00 in the message.""" - _, message = msg.summary("doc.pdf", balance_micros=-5_000, required_micros=500_000) - assert "$0.00 of credit left" in message - - -def test_summary_truncates_long_name(): - """A long document name is truncated in the title.""" - title, _ = msg.summary("a" * 50, balance_micros=0, required_micros=1_000) - assert title == f"Insufficient credits: {'a' * 40}..." diff --git a/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py b/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py new file mode 100644 index 000000000..606e985f2 --- /dev/null +++ b/surfsense_backend/tests/unit/notifications/service/messages/test_page_limit.py @@ -0,0 +1,32 @@ +"""Unit tests for page-limit presentation logic.""" + +from __future__ import annotations + +import pytest + +from app.notifications.service.messages import page_limit as msg + +pytestmark = pytest.mark.unit + + +def test_operation_id_encodes_search_space(): + """The operation id embeds the search space id.""" + assert msg.operation_id("doc.pdf", 9).startswith("page_limit_9_") + + +def test_summary_title_and_message(): + """The summary states the document and the used/limit page counts.""" + title, message = msg.summary( + "short.pdf", pages_used=95, pages_limit=100, pages_to_add=10 + ) + assert title == "Page limit exceeded: short.pdf" + assert message == ( + "This document has ~10 page(s) but you've used 95/100 pages. " + "Upgrade to process more documents." + ) + + +def test_summary_truncates_long_name(): + """A long document name is truncated in the title.""" + title, _ = msg.summary("a" * 50, pages_used=1, pages_limit=2, pages_to_add=1) + assert title == f"Page limit exceeded: {'a' * 40}..." diff --git a/surfsense_backend/tests/unit/observability/test_helpers.py b/surfsense_backend/tests/unit/observability/test_helpers.py index eafb8b626..ae60c1939 100644 --- a/surfsense_backend/tests/unit/observability/test_helpers.py +++ b/surfsense_backend/tests/unit/observability/test_helpers.py @@ -31,10 +31,10 @@ def _disable_otel(monkeypatch: pytest.MonkeyPatch): ("process_file_upload_with_document", "process"), ("process_circleback_meeting", "process"), ("generate_video_presentation", "generate"), - ("podcast.draft_transcript", "podcast.draft"), - ("podcast.render_audio", "podcast.render"), + ("generate_content_podcast", "generate"), ("cleanup_stale_indexing_notifications", "cleanup"), - ("reconcile_pending_stripe_credit_purchases", "reconcile"), + ("reconcile_pending_stripe_page_purchases", "reconcile"), + ("reconcile_pending_stripe_token_purchases", "reconcile"), ("check_periodic_schedules", "check"), ("ai_sort_search_space", "ai"), ("index_notion_pages", "index"), diff --git a/surfsense_backend/tests/unit/podcasts/conftest.py b/surfsense_backend/tests/unit/podcasts/conftest.py deleted file mode 100644 index 5eb4d8457..000000000 --- a/surfsense_backend/tests/unit/podcasts/conftest.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Shared builders for podcast unit tests. - -These tests exercise pure logic through public interfaces with no test doubles: -the brief and transcript factories build valid aggregates so each test states -only the fields it cares about. Stateful, persistence-backed paths (the lifecycle -service, the Celery task bodies) are covered by the integration suite against a -real database. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, -) - - -@pytest.fixture -def make_spec(): - """Factory for a valid :class:`PodcastSpec`; override only what matters.""" - - def _make( - *, - language: str = "en", - style: PodcastStyle = PodcastStyle.CONVERSATIONAL, - speakers: list[SpeakerSpec] | None = None, - min_minutes: int = 10, - max_minutes: int = 20, - focus: str | None = None, - ) -> PodcastSpec: - if speakers is None: - speakers = [ - SpeakerSpec( - slot=0, - name="Host", - role=SpeakerRole.HOST, - voice_id="kokoro:am_adam", - ), - SpeakerSpec( - slot=1, - name="Guest", - role=SpeakerRole.GUEST, - voice_id="kokoro:af_bella", - ), - ] - return PodcastSpec( - language=language, - style=style, - speakers=speakers, - duration=DurationTarget(min_minutes=min_minutes, max_minutes=max_minutes), - focus=focus, - ) - - return _make - - -@pytest.fixture -def make_transcript(): - """Factory for a valid :class:`Transcript`.""" - - def _make(turns: list[tuple[int, str]] | None = None) -> Transcript: - if turns is None: - turns = [(0, "Welcome to the show."), (1, "Glad to be here.")] - return Transcript( - turns=[TranscriptTurn(speaker=slot, text=text) for slot, text in turns] - ) - - return _make diff --git a/surfsense_backend/tests/unit/podcasts/test_api_schemas.py b/surfsense_backend/tests/unit/podcasts/test_api_schemas.py deleted file mode 100644 index 41664ac64..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_api_schemas.py +++ /dev/null @@ -1,94 +0,0 @@ -"""The API read model the frontend renders from. - -``PodcastDetail.of`` maps a stored podcast row to the detail view and action -responses: it exposes the deserialized brief and transcript and a simple -``has_audio`` flag the client can't derive from the published Zero columns. Each -test builds a row in one lifecycle shape and asserts the mapping reflects it. -""" - -from __future__ import annotations - -from datetime import UTC, datetime - -import pytest - -from app.podcasts.api.schemas import PodcastDetail -from app.podcasts.persistence import Podcast, PodcastStatus - -pytestmark = pytest.mark.unit - - -def _podcast(*, status: PodcastStatus = PodcastStatus.PENDING, **columns) -> Podcast: - """A persisted-looking row: the id and created_at a saved podcast would carry.""" - podcast = Podcast( - title="Episode", - search_space_id=3, - status=status, - spec_version=1, - **columns, - ) - podcast.id = 1 - podcast.created_at = datetime.now(UTC) - return podcast - - -def test_a_fresh_podcast_exposes_no_brief_transcript_or_audio(): - detail = PodcastDetail.of(_podcast()) - - assert detail.status == PodcastStatus.PENDING - assert detail.spec is None - assert detail.transcript is None - assert detail.has_audio is False - - -def test_an_awaiting_brief_podcast_exposes_the_deserialized_brief(make_spec): - podcast = _podcast( - status=PodcastStatus.AWAITING_BRIEF, - spec=make_spec(language="fr").model_dump(mode="json"), - ) - - detail = PodcastDetail.of(podcast) - - assert detail.spec is not None - assert detail.spec.language == "fr" - - -def test_a_legacy_episode_still_exposes_its_transcript_and_audio(): - # Pre-rework rows stored [{speaker_id, dialog}] and a local file path; - # they must keep flowing through the new read model, not fail validation. - podcast = _podcast( - status=PodcastStatus.READY, - podcast_transcript=[ - {"speaker_id": 0, "dialog": "Welcome back."}, - {"speaker_id": 1, "dialog": "Glad to be here."}, - ], - file_location="/var/old/podcast.mp3", - ) - - detail = PodcastDetail.of(podcast) - - assert detail.has_audio is True - assert detail.transcript is not None - assert [(turn.speaker, turn.text) for turn in detail.transcript.turns] == [ - (0, "Welcome back."), - (1, "Glad to be here."), - ] - - -def test_a_ready_podcast_reports_available_audio(make_spec, make_transcript): - podcast = _podcast( - status=PodcastStatus.READY, - spec=make_spec().model_dump(mode="json"), - podcast_transcript=make_transcript().model_dump(mode="json"), - storage_backend="local", - storage_key="k", - duration_seconds=120, - ) - - detail = PodcastDetail.of(podcast) - - assert detail.status == PodcastStatus.READY - assert detail.has_audio is True - assert detail.duration_seconds == 120 - assert detail.transcript is not None - assert detail.error is None diff --git a/surfsense_backend/tests/unit/podcasts/test_renderer.py b/surfsense_backend/tests/unit/podcasts/test_renderer.py deleted file mode 100644 index 2bcdff967..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_renderer.py +++ /dev/null @@ -1,94 +0,0 @@ -"""The renderer refuses an inconsistent spec/transcript before spending work. - -Full synthesis-and-merge needs FFmpeg and a real provider, so it belongs to an -integration test. What is pure and worth securing here is the renderer's -contract that it validates the transcript against the brief up front: a turn -naming an unknown speaker, or a speaker naming an unknown voice, fails loudly -rather than producing silent or wrong audio. The TTS provider is an external -port, faked here and never expected to be called on these paths. -""" - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from app.podcasts.rendering import PodcastRenderer, RenderError -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, -) -from app.podcasts.tts import SynthesizedAudio -from app.podcasts.voices import CatalogVoice, TtsProvider, VoiceCatalog, VoiceGender - -pytestmark = pytest.mark.unit - - -class _UnusedTTS: - """A TTS port double that fails the test if it is ever asked to speak. - - These behaviors must short-circuit before synthesis, so any call here is a - regression. - """ - - @property - def container(self) -> str: - return "mp3" - - async def synthesize(self, _request): # pragma: no cover - must not run - raise AssertionError("synthesis should not be attempted") - return SynthesizedAudio(data=b"", container="mp3") - - -def _catalog_with(voice_id: str) -> VoiceCatalog: - return VoiceCatalog( - [ - CatalogVoice( - voice_id=voice_id, - provider=TtsProvider.KOKORO, - language="en-US", - display_name=voice_id, - gender=VoiceGender.MALE, - native_ref="am_adam", - ) - ] - ) - - -def _spec(voice_id: str) -> PodcastSpec: - return PodcastSpec( - language="en", - speakers=[ - SpeakerSpec(slot=0, name="Host", role=SpeakerRole.HOST, voice_id=voice_id) - ], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - - -async def test_render_rejects_a_turn_for_an_unknown_speaker(tmp_path): - renderer = PodcastRenderer( - tts=_UnusedTTS(), catalog=_catalog_with("kokoro:am_adam") - ) - transcript = Transcript(turns=[TranscriptTurn(speaker=5, text="Who am I?")]) - - with pytest.raises(RenderError): - await renderer.render( - spec=_spec("kokoro:am_adam"), transcript=transcript, workdir=Path(tmp_path) - ) - - -async def test_render_rejects_a_speaker_whose_voice_is_not_in_the_catalog(tmp_path): - renderer = PodcastRenderer( - tts=_UnusedTTS(), catalog=_catalog_with("kokoro:am_adam") - ) - transcript = Transcript(turns=[TranscriptTurn(speaker=0, text="Hello.")]) - - with pytest.raises(RenderError): - await renderer.render( - spec=_spec("kokoro:ghost"), transcript=transcript, workdir=Path(tmp_path) - ) diff --git a/surfsense_backend/tests/unit/podcasts/test_resolution.py b/surfsense_backend/tests/unit/podcasts/test_resolution.py deleted file mode 100644 index aab44f8fb..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_resolution.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Default language and voice selection for a fresh brief. - -Resolution is what lets most briefs need no edits: it proposes a sensible -language and a distinct voice per speaker. These tests state the policy -("reuse what the user last chose, else English"; "two speakers should sound -like two people") through the public resolver functions and the real catalog. -We never guess the language from source content. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.resolution import ( - DEFAULT_LANGUAGE, - LanguageContext, - VoiceResolutionError, - resolve_language, - resolve_voices, -) -from app.podcasts.voices import TtsProvider, get_voice_catalog - -pytestmark = pytest.mark.unit - - -def test_last_used_language_is_reused(): - context = LanguageContext(last_used="fr") - assert resolve_language(context) == "fr" - - -def test_first_time_user_with_no_signal_gets_the_default(): - assert resolve_language(LanguageContext()) == DEFAULT_LANGUAGE - - -def test_two_speakers_get_distinct_voices(): - """A two-speaker episode should not voice both with the same person.""" - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, provider=TtsProvider.KOKORO, language="en", speaker_count=2 - ) - assert len(voices) == 2 - assert voices[0].voice_id != voices[1].voice_id - - -def test_a_users_preferred_voice_is_reused_when_still_valid(): - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, - provider=TtsProvider.KOKORO, - language="en", - speaker_count=2, - preferred=["kokoro:af_bella"], - ) - assert voices[0].voice_id == "kokoro:af_bella" - - -def test_a_preferred_voice_invalid_for_the_language_is_replaced(): - """A stale preference (wrong provider/language) is silently dropped.""" - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, - provider=TtsProvider.KOKORO, - language="en", - speaker_count=1, - preferred=["kokoro:does-not-exist"], - ) - assert voices[0].voice_id in { - v.voice_id for v in catalog.for_provider(TtsProvider.KOKORO) - } - - -def test_resolution_fails_when_no_voice_speaks_the_language(): - """If a provider can't speak the language at all, that is surfaced loudly.""" - catalog = get_voice_catalog() - with pytest.raises(VoiceResolutionError): - resolve_voices( - catalog=catalog, - provider=TtsProvider.KOKORO, - language="xx", - speaker_count=1, - ) - - -def test_every_speaker_is_assigned_even_when_voices_run_out(): - """With one available voice, both speakers still get one rather than failing.""" - catalog = get_voice_catalog() - voices = resolve_voices( - catalog=catalog, provider=TtsProvider.KOKORO, language="fr", speaker_count=2 - ) - assert len(voices) == 2 - - -def test_speaker_count_must_be_positive(): - catalog = get_voice_catalog() - with pytest.raises(ValueError): - resolve_voices( - catalog=catalog, provider=TtsProvider.KOKORO, language="en", speaker_count=0 - ) diff --git a/surfsense_backend/tests/unit/podcasts/test_spec.py b/surfsense_backend/tests/unit/podcasts/test_spec.py deleted file mode 100644 index 4efd530e9..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_spec.py +++ /dev/null @@ -1,163 +0,0 @@ -"""The brief and transcript contracts. - -A brief is what a user approves before any tokens or audio are spent, so its -validation rules are real behavior: they are the guardrails that keep a -nonsensical or ambiguous brief from ever reaching the expensive stages. These -tests pin those rules through construction of the public Pydantic models. -""" - -from __future__ import annotations - -import pytest -from pydantic import ValidationError - -from app.podcasts.schemas import ( - DurationTarget, - PodcastSpec, - PodcastStyle, - SpeakerRole, - SpeakerSpec, - Transcript, - TranscriptTurn, - normalize_language_tag, -) - -pytestmark = pytest.mark.unit - - -def _speaker(slot: int, voice_id: str = "kokoro:am_adam") -> SpeakerSpec: - return SpeakerSpec( - slot=slot, name=f"Speaker {slot}", role=SpeakerRole.HOST, voice_id=voice_id - ) - - -@pytest.mark.parametrize( - ("raw", "expected"), - [ - ("EN", "en"), - ("en-US", "en-US"), - ("PT-BR", "pt-BR"), - (" fr ", "fr"), - ], -) -def test_language_is_normalized_to_canonical_form(raw, expected): - """The primary subtag is lowercased and surrounding space trimmed.""" - assert normalize_language_tag(raw) == expected - - -@pytest.mark.parametrize("invalid", ["", "e", "english!", "123", "en_US"]) -def test_invalid_language_tags_are_rejected(invalid): - """Tags that are not BCP-47-shaped never reach a brief.""" - with pytest.raises(ValueError): - normalize_language_tag(invalid) - - -def test_spec_normalizes_its_language_on_construction(): - """A brief stores a canonical language regardless of how it was entered.""" - spec = PodcastSpec( - language="EN-us", - speakers=[_speaker(0)], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - assert spec.language == "en-us" - - -def test_speakers_must_have_unique_slots(): - """Slots are the join key to transcript turns, so duplicates are invalid.""" - with pytest.raises(ValidationError): - PodcastSpec( - language="en", - speakers=[_speaker(0), _speaker(0, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - - -def test_a_brief_needs_at_least_one_speaker(): - with pytest.raises(ValidationError): - PodcastSpec( - language="en", - speakers=[], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - - -def test_a_monologue_brief_carries_exactly_one_speaker(): - spec = PodcastSpec( - language="en", - style=PodcastStyle.MONOLOGUE, - speakers=[_speaker(0)], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - assert spec.style is PodcastStyle.MONOLOGUE - - -def test_a_monologue_brief_rejects_multiple_speakers(): - """One voice is what 'monologue' means; a second speaker is a user error.""" - with pytest.raises(ValidationError): - PodcastSpec( - language="en", - style=PodcastStyle.MONOLOGUE, - speakers=[_speaker(0), _speaker(1, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - - -def test_duration_rejects_an_inverted_range(): - """A max below the min is a user error caught at the brief gate.""" - with pytest.raises(ValidationError): - DurationTarget(min_minutes=20, max_minutes=10) - - -def test_duration_midpoint_is_where_drafting_aims(): - assert DurationTarget(min_minutes=10, max_minutes=20).midpoint_minutes == 15 - - -def test_blank_focus_becomes_absent(): - """Whitespace-only steer is treated as no steer.""" - spec = PodcastSpec( - language="en", - speakers=[_speaker(0)], - duration=DurationTarget(min_minutes=5, max_minutes=10), - focus=" ", - ) - assert spec.focus is None - - -def test_speaker_for_returns_the_speaker_bound_to_a_slot(): - spec = PodcastSpec( - language="en", - speakers=[_speaker(0), _speaker(1, voice_id="kokoro:af_bella")], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - assert spec.speaker_for(1).voice_id == "kokoro:af_bella" - - -def test_speaker_for_raises_when_no_speaker_matches(): - spec = PodcastSpec( - language="en", - speakers=[_speaker(0)], - duration=DurationTarget(min_minutes=5, max_minutes=10), - ) - with pytest.raises(KeyError): - spec.speaker_for(99) - - -def test_transcript_word_count_sums_spoken_words(): - """Word count is what drafting checks runtime against, so it must be exact.""" - transcript = Transcript( - turns=[ - TranscriptTurn(speaker=0, text="hello there world"), - TranscriptTurn(speaker=1, text="one two"), - ] - ) - assert transcript.word_count == 5 - - -def test_blank_transcript_turns_are_rejected(): - with pytest.raises(ValidationError): - TranscriptTurn(speaker=0, text=" ") - - -def test_a_transcript_needs_at_least_one_turn(): - with pytest.raises(ValidationError): - Transcript(turns=[]) diff --git a/surfsense_backend/tests/unit/podcasts/test_structured.py b/surfsense_backend/tests/unit/podcasts/test_structured.py deleted file mode 100644 index 8d7b2226a..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_structured.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Parsing a model's reply into a structured shape. - -Agent LLMs wrap JSON in prose and markdown fences. ``invoke_json`` exists so -every generation node tolerates that the same way. The LLM is an external -boundary, so it is faked with a canned reply; the behavior under test is the -parsing, not the model. -""" - -from __future__ import annotations - -import pytest -from pydantic import BaseModel - -from app.podcasts.generation.structured import StructuredOutputError, invoke_json - -pytestmark = pytest.mark.unit - - -class _Shape(BaseModel): - name: str - count: int - - -class _CannedLLM: - """A TTS-free stand-in for the chat model: replies with one fixed string.""" - - def __init__(self, reply: str) -> None: - self._reply = reply - - async def ainvoke(self, _messages): - return SimpleReply(self._reply) - - -class SimpleReply: - def __init__(self, content: str) -> None: - self.content = content - - -async def _parse(reply: str) -> _Shape: - return await invoke_json(_CannedLLM(reply), [], _Shape) - - -async def test_parses_a_clean_json_reply(): - shape = await _parse('{"name": "alpha", "count": 3}') - assert shape == _Shape(name="alpha", count=3) - - -async def test_parses_json_wrapped_in_a_markdown_fence(): - reply = '```json\n{"name": "beta", "count": 7}\n```' - shape = await _parse(reply) - assert shape == _Shape(name="beta", count=7) - - -async def test_extracts_json_embedded_in_prose(): - """Reasoning models prepend/append chatter around the object.""" - reply = 'Sure, here you go: {"name": "gamma", "count": 1} — hope that helps!' - shape = await _parse(reply) - assert shape == _Shape(name="gamma", count=1) - - -async def test_raises_when_there_is_no_json_object(): - with pytest.raises(StructuredOutputError): - await _parse("I could not produce that.") - - -async def test_raises_when_the_json_does_not_match_the_shape(): - with pytest.raises(StructuredOutputError): - await _parse('{"name": "delta"}') diff --git a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py b/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py deleted file mode 100644 index 861d8768c..000000000 --- a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py +++ /dev/null @@ -1,105 +0,0 @@ -"""The voice catalog and provider identification. - -The catalog is the single source of truth for which voices exist; resolution, -the API picker, and the renderer all depend on its lookups behaving correctly. -These tests build a small catalog of their own so they assert on the lookup -behavior, not on which specific voices ship. -""" - -from __future__ import annotations - -import pytest - -from app.podcasts.voices import ( - ANY_LANGUAGE, - CatalogVoice, - TtsProvider, - VoiceCatalog, - VoiceGender, - provider_from_service, -) - -pytestmark = pytest.mark.unit - - -def _voice( - voice_id: str, - *, - provider: TtsProvider = TtsProvider.KOKORO, - language: str = "en-US", - gender: VoiceGender = VoiceGender.MALE, -) -> CatalogVoice: - return CatalogVoice( - voice_id=voice_id, - provider=provider, - language=language, - display_name=voice_id, - gender=gender, - native_ref=voice_id, - ) - - -def test_for_provider_returns_only_that_providers_voices(): - catalog = VoiceCatalog( - [ - _voice("k1", provider=TtsProvider.KOKORO), - _voice("o1", provider=TtsProvider.OPENAI), - ] - ) - assert [v.voice_id for v in catalog.for_provider(TtsProvider.KOKORO)] == ["k1"] - - -def test_for_language_matches_on_the_primary_subtag(): - """A request for 'en' should match an 'en-US' voice (region-insensitive).""" - catalog = VoiceCatalog([_voice("k1", language="en-US")]) - assert [v.voice_id for v in catalog.for_language(TtsProvider.KOKORO, "en")] == [ - "k1" - ] - - -def test_for_language_excludes_other_languages(): - catalog = VoiceCatalog([_voice("k1", language="en-US")]) - assert catalog.for_language(TtsProvider.KOKORO, "fr") == [] - - -def test_an_any_language_voice_speaks_every_language(): - """Provider-agnostic voices (e.g. OpenAI) match whatever the text is in.""" - voice = _voice("o1", provider=TtsProvider.OPENAI, language=ANY_LANGUAGE) - assert voice.speaks("ja") - assert voice.speaks("pt-BR") - - -def test_supports_language_reports_availability(): - catalog = VoiceCatalog([_voice("k1", language="en-US")]) - assert catalog.supports_language(TtsProvider.KOKORO, "en") - assert not catalog.supports_language(TtsProvider.KOKORO, "de") - - -def test_get_raises_for_an_unknown_voice(): - catalog = VoiceCatalog([_voice("k1")]) - with pytest.raises(KeyError): - catalog.get("nope") - - -def test_a_catalog_rejects_duplicate_voice_ids(): - """Stored ids must be unique so a brief's voice_id resolves unambiguously.""" - with pytest.raises(ValueError): - VoiceCatalog([_voice("dup"), _voice("dup")]) - - -@pytest.mark.parametrize( - ("service", "expected"), - [ - ("openai/tts-1", TtsProvider.OPENAI), - ("azure/neural", TtsProvider.AZURE), - ("vertex_ai/some-model", TtsProvider.VERTEX_AI), - ("local/kokoro", TtsProvider.KOKORO), - ], -) -def test_provider_is_identified_from_the_config_string(service, expected): - assert provider_from_service(service) == expected - - -def test_unknown_provider_prefix_is_rejected(): - with pytest.raises(ValueError): - provider_from_service("madeup/model") diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 5c5c90283..d1af29aeb 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -90,7 +90,7 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -138,7 +138,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -196,7 +196,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -232,11 +232,11 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): async def _must_not_call(*_args, **_kwargs): raise AssertionError( - "credit_get_usage should not be called for valid pin reuse" + "premium_get_usage should not be called for valid pin reuse" ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _must_not_call, ) @@ -275,7 +275,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -320,7 +320,7 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -365,7 +365,7 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -410,7 +410,7 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -470,7 +470,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -529,7 +529,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -581,7 +581,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -633,7 +633,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -686,7 +686,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -754,7 +754,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): return _FakeQuotaResult(allowed=True) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _allowed, ) @@ -803,10 +803,10 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): ) async def _must_not_call(*_args, **_kwargs): - raise AssertionError("credit_get_usage should not run on pin reuse") + raise AssertionError("premium_get_usage should not run on pin reuse") monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _must_not_call, ) @@ -864,7 +864,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) @@ -904,10 +904,10 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): ) async def _must_not_call(*_args, **_kwargs): - raise AssertionError("credit_get_usage should not run on healthy pin reuse") + raise AssertionError("premium_get_usage should not run on healthy pin reuse") monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _must_not_call, ) @@ -962,7 +962,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa return _FakeQuotaResult(allowed=False) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _blocked, ) diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py index 3ca5c7a67..0e19b80e4 100644 --- a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -114,7 +114,7 @@ async def test_image_turn_filters_out_text_only_candidates(monkeypatch): [_text_only_cfg(-1), _vision_cfg(-2)], ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -146,7 +146,7 @@ async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch): [_text_only_cfg(-1), _vision_cfg(-2)], ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -178,7 +178,7 @@ async def test_image_turn_reuses_existing_vision_pin(monkeypatch): [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)], ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -209,7 +209,7 @@ async def test_image_turn_with_no_vision_candidates_raises(monkeypatch): [_text_only_cfg(-1), _text_only_cfg(-2)], ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -237,7 +237,7 @@ async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch): [_text_only_cfg(-1)], ) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) @@ -271,7 +271,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): } monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision]) monkeypatch.setattr( - "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", _premium_allowed, ) diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py index 8e2c2f1da..c820724ed 100644 --- a/surfsense_backend/tests/unit/services/test_billable_call.py +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -38,13 +38,11 @@ class _FakeQuotaResult: used: int = 0, limit: int = 5_000_000, remaining: int = 5_000_000, - balance: int = 5_000_000, ) -> None: self.allowed = allowed self.used = used self.limit = limit self.remaining = remaining - self.balance = balance class _FakeSession: @@ -120,17 +118,17 @@ def _patch_isolation_layer( return object() monkeypatch.setattr( - "app.services.billable_calls.TokenQuotaService.credit_reserve", + "app.services.billable_calls.TokenQuotaService.premium_reserve", _fake_reserve, raising=False, ) monkeypatch.setattr( - "app.services.billable_calls.TokenQuotaService.credit_finalize", + "app.services.billable_calls.TokenQuotaService.premium_finalize", _fake_finalize, raising=False, ) monkeypatch.setattr( - "app.services.billable_calls.TokenQuotaService.credit_release", + "app.services.billable_calls.TokenQuotaService.premium_release", _fake_release, raising=False, ) @@ -203,7 +201,9 @@ async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch): spies = _patch_isolation_layer( monkeypatch, - reserve_result=_FakeQuotaResult(allowed=False, balance=0, remaining=0), + reserve_result=_FakeQuotaResult( + allowed=False, used=5_000_000, limit=5_000_000, remaining=0 + ), ) user_id = uuid4() @@ -220,7 +220,8 @@ async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch): err = exc_info.value assert err.usage_type == "image_generation" - assert err.balance_micros == 0 + assert err.used_micros == 5_000_000 + assert err.limit_micros == 5_000_000 assert err.remaining_micros == 0 # Reserve was attempted, but no finalize/release on a denied reserve # — the reservation never actually held credit. @@ -531,7 +532,7 @@ async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): spies = _patch_isolation_layer( monkeypatch, reserve_result=_FakeQuotaResult( - allowed=False, balance=500_000, remaining=500_000 + allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000 ), ) user_id = uuid4() @@ -551,7 +552,6 @@ async def test_premium_video_denial_raises_quota_insufficient(monkeypatch): err = exc_info.value assert err.usage_type == "video_presentation_generation" - assert err.balance_micros == 500_000 assert err.remaining_micros == 500_000 assert spies["reserve"][0]["reserve_micros"] == 1_000_000 assert spies["finalize"] == [] diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py index 2342dd8da..a5bb3f58a 100644 --- a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py +++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py @@ -239,18 +239,17 @@ def test_video_presentation_task_uses_runner_helper() -> None: ) -def test_podcast_tasks_use_runner_helper() -> None: - """Symmetric assertion for the podcast tasks — same root cause, same +def test_podcast_task_uses_runner_helper() -> None: + """Symmetric assertion for the podcast task — same root cause, same fix, same regression risk. """ import inspect - from app.podcasts.tasks import draft, render + from app.tasks.celery_tasks import podcast_tasks - for module in (draft, render): - src = inspect.getsource(module) - assert "run_async_celery_task" in src - assert "asyncio.new_event_loop" not in src + src = inspect.getsource(podcast_tasks) + assert "run_async_celery_task" in src + assert "asyncio.new_event_loop" not in src def test_runner_runs_shutdown_asyncgens_before_close() -> None: diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py new file mode 100644 index 000000000..699297df1 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -0,0 +1,388 @@ +"""Unit tests for podcast Celery task billing integration. + +Validates ``_generate_content_podcast`` correctly wraps +``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the +search-space owner's billing decision, and degrades cleanly when the +resolver fails or premium credit is exhausted. + +Coverage: + +* Happy-path free config: resolver → ``billable_call`` enters with + ``usage_type='podcast_generation'`` and the configured reserve override, + graph runs, podcast row flips to ``READY``. +* Happy-path premium config: same wiring with ``billing_tier='premium'``. +* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` → + graph is *not* invoked, podcast row flips to ``FAILED``, return dict + carries ``reason='premium_quota_exhausted'``. +* Resolver failure: ``ValueError`` from the resolver → podcast row flips + to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class _FakeExecResult: + def __init__(self, obj): + self._obj = obj + + def scalars(self): + return self + + def first(self): + return self._obj + + def filter(self, *_args, **_kwargs): + return self + + +class _FakeSession: + def __init__(self, podcast): + self._podcast = podcast + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self._podcast) + + async def commit(self): + self.commit_count += 1 + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + +class _FakeSessionMaker: + def __init__(self, session: _FakeSession): + self._session = session + + def __call__(self): + return self._session + + +def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace: + """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily + inside helpers keeps this fixture cheap.""" + return SimpleNamespace( + id=podcast_id, + title="Test Podcast", + thread_id=thread_id, + status=None, + podcast_transcript=None, + file_location=None, + ) + + +@contextlib.asynccontextmanager +async def _ok_billable_call(**kwargs): + """Stand-in for ``billable_call`` that records its kwargs and yields a + no-op accumulator-shaped object.""" + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + + +_CALL_LOG: list[dict[str, Any]] = [] + + +@contextlib.asynccontextmanager +async def _denying_billable_call(**kwargs): + from app.services.billable_calls import QuotaInsufficientError + + _CALL_LOG.append(kwargs) + raise QuotaInsufficientError( + usage_type=kwargs.get("usage_type", "?"), + used_micros=5_000_000, + limit_micros=5_000_000, + remaining_micros=0, + ) + yield SimpleNamespace() # pragma: no cover — for grammar only + + +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_call_log(): + _CALL_LOG.clear() + yield + _CALL_LOG.clear() + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): + """Happy path: free billing tier still wraps the graph call so the + audit row is recorded. Verifies kwargs threading.""" + from app.config import config as app_config + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=7, thread_id=99) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + assert search_space_id == 555 + assert thread_id == 99 + return user_id, "free", "openrouter/some-free-model" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return { + "podcast_transcript": [ + SimpleNamespace(speaker_id=0, dialog="Hi"), + SimpleNamespace(speaker_id=1, dialog="Hello"), + ], + "final_podcast_file_path": "/tmp/podcast.wav", + } + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hello world", + search_space_id=555, + user_prompt="make it short", + ) + + assert result["status"] == "ready" + assert result["podcast_id"] == 7 + assert podcast.status == PodcastStatus.READY + assert podcast.file_location == "/tmp/podcast.wav" + + assert len(_CALL_LOG) == 1 + call = _CALL_LOG[0] + assert call["user_id"] == user_id + assert call["search_space_id"] == 555 + assert call["billing_tier"] == "free" + assert call["base_model"] == "openrouter/some-free-model" + assert call["usage_type"] == "podcast_generation" + assert ( + call["quota_reserve_micros_override"] + == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS + ) + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "podcast_id": 7, + "title": "Test Podcast", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) + + +@pytest.mark.asyncio +async def test_billable_call_invoked_with_premium_tier(monkeypatch): + """Premium resolution flows through to ``billable_call`` so the + reserve/finalize path triggers.""" + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast() + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + user_id = uuid4() + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return user_id, "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + await podcast_tasks._generate_content_podcast( + podcast_id=7, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert _CALL_LOG[0]["billing_tier"] == "premium" + assert _CALL_LOG[0]["base_model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch): + """When ``billable_call`` denies the reservation, the graph never + runs and the podcast row flips to FAILED with the documented reason + code.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=8) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=8, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 8, + "reason": "premium_quota_exhausted", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] # Graph never ran on denied reservation. + + +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch): + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=10) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr( + podcast_tasks, "billable_call", _settlement_failing_billable_call + ) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=10, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 10, + "reason": "billing_settlement_failed", + } + assert podcast.status == PodcastStatus.FAILED + + +@pytest.mark.asyncio +async def test_resolver_failure_marks_podcast_failed(monkeypatch): + """If the resolver raises (e.g. search-space deleted), the task fails + cleanly without invoking the graph.""" + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=9) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _failing_resolver(sess, search_space_id, *, thread_id=None): + raise ValueError("Search space 555 not found") + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver + ) + + graph_invoked = [] + + async def _fake_graph_invoke(state, config): + graph_invoked.append(True) + return {} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=9, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 9, + "reason": "billing_resolution_failed", + } + assert podcast.status == PodcastStatus.FAILED + assert graph_invoked == [] diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 182b9679f..a927a928d 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -15,9 +15,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -31,9 +28,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -47,9 +41,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -63,9 +54,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -79,9 +67,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -95,9 +80,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -111,9 +93,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -127,9 +106,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -143,9 +119,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -159,9 +132,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra == 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -179,10 +149,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -196,9 +162,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -212,9 +175,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -228,9 +188,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -244,9 +201,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -260,9 +214,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -276,9 +227,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -292,9 +240,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -308,9 +253,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -324,9 +266,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -340,9 +279,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra == 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -360,10 +296,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -377,9 +309,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -393,9 +322,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'linux' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -409,9 +335,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -425,9 +348,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -441,9 +361,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -457,9 +374,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -473,9 +387,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -489,9 +400,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -505,9 +413,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -521,9 +426,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra == 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -541,10 +443,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -558,9 +456,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -586,12 +481,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -605,9 +494,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -621,9 +507,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform == 'emscripten' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -649,12 +532,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -680,12 +557,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform != 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", "python_version < '0'", "python_version < '0'", @@ -699,9 +570,6 @@ resolution-markers = [ "python_version < '0'", "python_version < '0'", "python_version < '0'", - "python_version < '0'", - "python_version < '0'", - "python_version < '0'", "python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-16-surf-new-backend-cpu' and extra != 'extra-16-surf-new-backend-cu126' and extra != 'extra-16-surf-new-backend-cu128'", ] conflicts = [[ @@ -9614,7 +9482,7 @@ wheels = [ [[package]] name = "surf-new-backend" -version = "0.0.28" +version = "0.0.27" source = { editable = "." } dependencies = [ { name = "alembic" }, diff --git a/surfsense_browser_extension/package.json b/surfsense_browser_extension/package.json index e7f0f082c..959e0b395 100644 --- a/surfsense_browser_extension/package.json +++ b/surfsense_browser_extension/package.json @@ -1,7 +1,7 @@ { "name": "surfsense_browser_extension", "displayName": "Surfsense Browser Extension", - "version": "0.0.28", + "version": "0.0.27", "description": "Extension to collect Browsing History for SurfSense.", "author": "https://github.com/MODSetter", "engines": { diff --git a/surfsense_desktop/package.json b/surfsense_desktop/package.json index f4cc9586d..433e33315 100644 --- a/surfsense_desktop/package.json +++ b/surfsense_desktop/package.json @@ -1,7 +1,7 @@ { "name": "surfsense-desktop", "productName": "SurfSense", - "version": "0.0.28", + "version": "0.0.27", "description": "SurfSense Desktop App", "main": "dist/main.js", "scripts": { diff --git a/surfsense_web/app/(home)/changelog/page.tsx b/surfsense_web/app/(home)/changelog/page.tsx index b7aa14d20..42bac512a 100644 --- a/surfsense_web/app/(home)/changelog/page.tsx +++ b/surfsense_web/app/(home)/changelog/page.tsx @@ -3,7 +3,10 @@ import type { MDXComponents } from "mdx/types"; import type { Metadata } from "next"; import type { ComponentType } from "react"; import { changelog } from "@/.source/server"; -import { ChangelogTimeline, type ChangelogTimelineEntry } from "@/components/ui/changelog-timeline"; +import { + ChangelogTimeline, + type ChangelogTimelineEntry, +} from "@/components/ui/changelog-timeline"; import { formatDate } from "@/lib/utils"; import { getMDXComponents } from "@/mdx-components"; diff --git a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx index 8ea4c1d7d..b4ec015b7 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/buy-more/page.tsx @@ -1,15 +1,48 @@ "use client"; -import { AutoReloadSettings } from "@/components/settings/auto-reload-settings"; -import { BuyCreditsContent } from "@/components/settings/buy-credits-content"; +import { useState } from "react"; +import { BuyPagesContent } from "@/components/settings/buy-pages-content"; +import { BuyTokensContent } from "@/components/settings/buy-tokens-content"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; + +const TABS = [ + { id: "pages", label: "Pages" }, + { id: "tokens", label: "Premium Credit" }, +] as const; + +type TabId = (typeof TABS)[number]["id"]; export default function BuyMorePage() { + const [activeTab, setActiveTab] = useState<TabId>("pages"); + return ( - <div className="flex min-h-[37rem] w-full select-none items-center justify-center py-8"> - <div className="w-full max-w-md space-y-8"> - <BuyCreditsContent /> - <AutoReloadSettings /> - </div> + <div className="w-full select-none"> + <Tabs + value={activeTab} + onValueChange={(value) => { + setActiveTab(value as TabId); + }} + className="relative min-h-[37rem] w-full" + > + <TabsList className="absolute top-20 left-1/2 -translate-x-1/2 rounded-xl bg-accent p-1"> + {TABS.map((tab) => ( + <TabsTrigger + key={tab.id} + value={tab.id} + className="h-8 rounded-lg px-4 text-sm font-semibold text-accent-foreground transition-colors hover:bg-transparent hover:text-white data-[state=active]:bg-[#4a4a4a] data-[state=active]:text-white data-[state=active]:shadow-none" + > + {tab.label} + </TabsTrigger> + ))} + </TabsList> + + <TabsContent value="pages" className="mt-0 flex min-h-[37rem] items-center pt-14"> + <BuyPagesContent /> + </TabsContent> + <TabsContent value="tokens" className="mt-0 flex min-h-[37rem] items-center pt-14"> + <BuyTokensContent /> + </TabsContent> + </Tabs> </div> ); } diff --git a/surfsense_web/app/dashboard/[search_space_id]/earn-credits/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/earn-credits/page.tsx deleted file mode 100644 index 3ff4c3cf8..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/earn-credits/page.tsx +++ /dev/null @@ -1,11 +0,0 @@ -"use client"; - -import { EarnCreditsContent } from "@/components/settings/earn-credits-content"; - -export default function EarnCreditsPage() { - return ( - <div className="w-full select-none space-y-6"> - <EarnCreditsContent /> - </div> - ); -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx index 46f1965d0..4b3301b9f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx @@ -1,18 +1,11 @@ "use client"; -import { useParams, useRouter } from "next/navigation"; -import { useEffect } from "react"; +import { MorePagesContent } from "@/components/settings/more-pages-content"; -// Legacy route kept as a redirect: older "insufficient credits" notifications -// and bookmarks may still point at /more-pages. export default function MorePagesPage() { - const router = useRouter(); - const params = useParams(); - const searchSpaceId = params?.search_space_id ?? ""; - - useEffect(() => { - router.replace(`/dashboard/${searchSpaceId}/earn-credits`); - }, [router, searchSpaceId]); - - return null; + return ( + <div className="w-full select-none space-y-6"> + <MorePagesContent /> + </div> + ); } diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f048376cc..75cfa4184 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -77,6 +77,11 @@ import { convertToThreadMessage, reconcileInterruptedAssistantMessages, } from "@/lib/chat/message-utils"; +import { + isPodcastGenerating, + looksLikePodcastRequest, + setActivePodcastTaskId, +} from "@/lib/chat/podcast-state"; import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; import { consumeSseEvents, processSharedStreamEvent } from "@/lib/chat/stream-pipeline"; import { @@ -753,9 +758,6 @@ export default function NewChatPage() { const loadedMessages = reconcileInterruptedAssistantMessages(messagesResponse.messages).map( convertToThreadMessage ); - if (messages.length > 0 && loadedMessages.length < messages.length) { - return; - } setMessages(loadedMessages); tokenUsageStore.clear(); @@ -776,7 +778,6 @@ export default function NewChatPage() { }, [ activeThreadId, isRunning, - messages.length, setMessageDocumentsMap, threadMessagesQuery.data, tokenUsageStore, @@ -953,6 +954,11 @@ export default function NewChatPage() { if (!userQuery.trim() && userImages.length === 0) return; + if (userQuery.trim() && isPodcastGenerating() && looksLikePodcastRequest(userQuery)) { + toast.warning("A podcast is already being generated."); + return; + } + const token = getBearerToken(); if (!token) { toast.error("Not authenticated. Please log in again."); @@ -1212,6 +1218,17 @@ export default function NewChatPage() { recentCancelRequestedAtRef.current = Date.now(); } }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } + } + } + }, }) ) { return; @@ -2170,6 +2187,17 @@ export default function NewChatPage() { recentCancelRequestedAtRef.current = Date.now(); } }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } + } + } + }, }) ) { return; @@ -2541,7 +2569,7 @@ export default function NewChatPage() { > <div key={searchSpaceId} className="flex h-full overflow-hidden"> <div className="relative flex-1 flex flex-col min-w-0 overflow-hidden"> - <Thread hasActiveThread={!!activeThreadId} /> + <Thread /> {isThreadMessagesLoading ? ( <div className="absolute inset-0 z-10 bg-panel"> <ThreadMessagesSkeleton /> diff --git a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx index a8a88c5a5..8eaec3e5a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx @@ -112,7 +112,9 @@ export default function PurchaseSuccessPage() { {state.kind === "still_pending" && "Your payment is still being processed by your bank. We'll apply your purchase as soon as it clears — usually within a few minutes. You can safely close this page."} {state.kind === "completed" && - `Added ${formatCredit(state.data.credit_micros_granted ?? 0)} of credit to your account.`} + (state.data.purchase_type === "page_packs" + ? `Added ${formatNumber(state.data.pages_granted ?? 0)} pages to your account.` + : `Added ${formatCredit(state.data.premium_credit_micros_granted ?? 0)} of premium credit to your account.`)} {state.kind === "failed" && "Stripe reported the checkout as failed or expired. Your card was not charged."} {state.kind === "error" && @@ -121,9 +123,18 @@ export default function PurchaseSuccessPage() { </CardDescription> </CardHeader> <CardContent className="space-y-3 text-center"> - {state.kind === "completed" && ( + {state.kind === "completed" && state.data.purchase_type === "page_packs" && ( <p className="text-sm text-muted-foreground"> - New credit balance: {formatCredit(state.data.credit_micros_balance ?? 0)} + New balance: {formatNumber(state.data.pages_limit ?? 0)} total pages + {typeof state.data.pages_used === "number" + ? ` (${formatNumber((state.data.pages_limit ?? 0) - state.data.pages_used)} remaining)` + : ""} + </p> + )} + {state.kind === "completed" && state.data.purchase_type === "premium_tokens" && ( + <p className="text-sm text-muted-foreground"> + New premium credit balance:{" "} + {formatCredit(state.data.premium_credit_micros_limit ?? 0)} </p> )} {state.kind === "error" && ( @@ -135,7 +146,7 @@ export default function PurchaseSuccessPage() { <Link href={`/dashboard/${searchSpaceId}/new-chat`}>Back to Dashboard</Link> </Button> <Button asChild variant="outline" className="w-full"> - <Link href={`/dashboard/${searchSpaceId}/buy-more`}>Buy credits</Link> + <Link href={`/dashboard/${searchSpaceId}/buy-more`}>Buy More</Link> </Button> </CardFooter> </Card> @@ -143,6 +154,10 @@ export default function PurchaseSuccessPage() { ); } +function formatNumber(n: number): string { + return new Intl.NumberFormat("en-US").format(n); +} + function formatCredit(micros: number): string { const dollars = micros / 1_000_000; return new Intl.NumberFormat("en-US", { diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx index 263a286c1..cf73b5eba 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent.tsx @@ -13,21 +13,25 @@ import { TableHeader, TableRow, } from "@/components/ui/table"; -import type { CreditPurchase, PagePurchase, PurchaseStatus } from "@/contracts/types/stripe.types"; +import type { + PagePurchase, + PagePurchaseStatus, + TokenPurchase, +} from "@/contracts/types/stripe.types"; import { stripeApiService } from "@/lib/apis/stripe-api.service"; import { cn } from "@/lib/utils"; -type PurchaseKind = "pages" | "credits"; +type PurchaseKind = "pages" | "tokens"; type UnifiedPurchase = { id: string; kind: PurchaseKind; created_at: string; - status: PurchaseStatus; + status: PagePurchaseStatus; /** * Granted units. Interpretation depends on ``kind``: - * - ``"pages"`` — integer number of indexed pages (legacy history). - * - ``"credits"`` — integer micro-USD of credit (1_000_000 = $1.00). + * - ``"pages"`` — integer number of indexed pages. + * - ``"tokens"`` — integer micro-USD of credit (1_000_000 = $1.00). * The ``Granted`` column formats accordingly. */ granted: number; @@ -35,7 +39,7 @@ type UnifiedPurchase = { currency: string | null; }; -const STATUS_STYLES: Record<PurchaseStatus, { label: string; className: string }> = { +const STATUS_STYLES: Record<PagePurchaseStatus, { label: string; className: string }> = { completed: { label: "Completed", className: "bg-emerald-600 text-white border-transparent hover:bg-emerald-600", @@ -59,8 +63,8 @@ const KIND_META: Record< icon: FileText, iconClass: "text-sky-500", }, - credits: { - label: "Credits", + tokens: { + label: "Premium Credit", icon: Coins, iconClass: "text-amber-500", }, @@ -93,10 +97,10 @@ function normalizePagePurchase(p: PagePurchase): UnifiedPurchase { }; } -function normalizeCreditPurchase(p: CreditPurchase): UnifiedPurchase { +function normalizeTokenPurchase(p: TokenPurchase): UnifiedPurchase { return { id: p.id, - kind: "credits", + kind: "tokens", created_at: p.created_at, status: p.status, granted: p.credit_micros_granted, @@ -106,10 +110,10 @@ function normalizeCreditPurchase(p: CreditPurchase): UnifiedPurchase { } function formatGranted(p: UnifiedPurchase): string { - if (p.kind === "credits") { + if (p.kind === "tokens") { const dollars = p.granted / 1_000_000; - // Credit packs are always whole dollars at the moment, but future - // fractional grants (refunds, partial top-ups, auto-reload) shouldn't + // Premium credit packs are always whole dollars at the moment, but + // future fractional grants (refunds, partial top-ups) shouldn't // silently round to "$0". if (dollars >= 1) return `$${dollars.toFixed(2)} of credit`; if (dollars > 0) return `$${dollars.toFixed(3)} of credit`; @@ -123,26 +127,26 @@ export function PurchaseHistoryContent() { queries: [ { queryKey: ["stripe-purchases"], - queryFn: () => stripeApiService.getPagePurchases(), + queryFn: () => stripeApiService.getPurchases(), }, { - queryKey: ["stripe-credit-purchases"], - queryFn: () => stripeApiService.getCreditPurchases(), + queryKey: ["stripe-token-purchases"], + queryFn: () => stripeApiService.getTokenPurchases(), }, ], }); - const [pagesQuery, creditsQuery] = results; - const isLoading = pagesQuery.isLoading || creditsQuery.isLoading; + const [pagesQuery, tokensQuery] = results; + const isLoading = pagesQuery.isLoading || tokensQuery.isLoading; const purchases = useMemo<UnifiedPurchase[]>(() => { const pagePurchases = pagesQuery.data?.purchases ?? []; - const creditPurchases = creditsQuery.data?.purchases ?? []; + const tokenPurchases = tokensQuery.data?.purchases ?? []; return [ ...pagePurchases.map(normalizePagePurchase), - ...creditPurchases.map(normalizeCreditPurchase), + ...tokenPurchases.map(normalizeTokenPurchase), ].sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime()); - }, [pagesQuery.data, creditsQuery.data]); + }, [pagesQuery.data, tokensQuery.data]); if (isLoading) { return ( @@ -158,7 +162,7 @@ export function PurchaseHistoryContent() { <ReceiptText className="h-8 w-8 text-muted-foreground" /> <p className="text-sm font-medium">No purchases yet</p> <p className="text-xs text-muted-foreground"> - Your credit purchases will appear here after checkout. + Your page and premium credit purchases will appear here after checkout. </p> </div> ); diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx index 55647fe29..3fa08c278 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/purchases/page.tsx @@ -1,11 +1,5 @@ -import { AutoReloadSettings } from "@/components/settings/auto-reload-settings"; import { PurchaseHistoryContent } from "../components/PurchaseHistoryContent"; export default function Page() { - return ( - <div className="space-y-6"> - <AutoReloadSettings /> - <PurchaseHistoryContent /> - </div> - ); + return <PurchaseHistoryContent />; } diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index d084ac0fd..fd24600c2 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -84,7 +84,7 @@ const GenerateResumeToolUI = dynamic( ); const GeneratePodcastToolUI = dynamic( () => - import("@/components/tool-ui/podcast").then((m) => ({ + import("@/components/tool-ui/generate-podcast").then((m) => ({ default: m.GeneratePodcastToolUI, })), { ssr: false } diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index dedada7a5..83308b642 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -27,8 +27,8 @@ export interface ChatViewportProps { export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( <ThreadPrimitive.Viewport turnAnchor="top" - autoScroll={false} - scrollToBottomOnRunStart={false} + autoScroll + scrollToBottomOnRunStart scrollToBottomOnInitialize scrollToBottomOnThreadSwitch className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth" diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 95f118835..5796109f0 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -144,15 +144,11 @@ function getComposerSuggestionAnchorPoint( }; } -interface ThreadProps { - hasActiveThread?: boolean; -} - -export const Thread: FC<ThreadProps> = ({ hasActiveThread = false }) => { - return <ThreadContent hasActiveThread={hasActiveThread} />; +export const Thread: FC = () => { + return <ThreadContent />; }; -const ThreadContent: FC<ThreadProps> = ({ hasActiveThread = false }) => { +const ThreadContent: FC = () => { return ( <ThreadPrimitive.Root className="aui-root aui-thread-root @container flex h-full min-h-0 flex-col bg-main-panel" @@ -162,13 +158,13 @@ const ThreadContent: FC<ThreadProps> = ({ hasActiveThread = false }) => { > <ChatViewport footer={ - <AuiIf condition={({ thread }) => hasActiveThread || !thread.isEmpty}> + <AuiIf condition={({ thread }) => !thread.isEmpty}> <PremiumQuotaPinnedAlert /> <Composer /> </AuiIf> } > - <AuiIf condition={({ thread }) => !hasActiveThread && thread.isEmpty}> + <AuiIf condition={({ thread }) => thread.isEmpty}> <ThreadWelcome /> </AuiIf> diff --git a/surfsense_web/components/layout/index.ts b/surfsense_web/components/layout/index.ts index eb475e414..67f161d1a 100644 --- a/surfsense_web/components/layout/index.ts +++ b/surfsense_web/components/layout/index.ts @@ -12,7 +12,6 @@ export type { export { ChatListItem, CreateSearchSpaceDialog, - CreditBalanceDisplay, Header, IconRail, LayoutShell, @@ -20,6 +19,7 @@ export { MobileSidebarTrigger, NavIcon, NavSection, + PageUsageDisplay, SearchSpaceAvatar, Sidebar, SidebarCollapseButton, diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 549e6e7d7..46f6ec8ae 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -186,40 +186,40 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid setStatusInboxItems(statusInbox.inboxItems); }, [statusInbox.inboxItems, setStatusInboxItems]); - // Track seen notification IDs to detect new insufficient_credits notifications - const seenCreditNotifications = useRef<Set<number>>(new Set()); + // Track seen notification IDs to detect new page_limit_exceeded notifications + const seenPageLimitNotifications = useRef<Set<number>>(new Set()); const isInitialLoad = useRef(true); - // Effect to show toast for new insufficient_credits notifications + // Effect to show toast for new page_limit_exceeded notifications useEffect(() => { if (statusInbox.loading) return; - const creditNotifications = statusInbox.inboxItems.filter( - (item) => item.type === "insufficient_credits" + const pageLimitNotifications = statusInbox.inboxItems.filter( + (item) => item.type === "page_limit_exceeded" ); if (isInitialLoad.current) { - for (const notification of creditNotifications) { - seenCreditNotifications.current.add(notification.id); + for (const notification of pageLimitNotifications) { + seenPageLimitNotifications.current.add(notification.id); } isInitialLoad.current = false; return; } - const newNotifications = creditNotifications.filter( - (notification) => !seenCreditNotifications.current.has(notification.id) + const newNotifications = pageLimitNotifications.filter( + (notification) => !seenPageLimitNotifications.current.has(notification.id) ); for (const notification of newNotifications) { - seenCreditNotifications.current.add(notification.id); + seenPageLimitNotifications.current.add(notification.id); toast.error(notification.title, { description: notification.message, duration: 8000, icon: <AlertTriangle className="h-5 w-5 text-amber-500" />, action: { - label: "Buy credits", - onClick: () => router.push(`/dashboard/${searchSpaceId}/buy-more`), + label: "Get More Pages", + onClick: () => router.push(`/dashboard/${searchSpaceId}/more-pages`), }, }); } @@ -696,7 +696,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid const isAutomationsPage = pathname?.includes("/automations") === true; const useWorkspacePanel = pathname?.endsWith("/buy-more") === true || - pathname?.endsWith("/earn-credits") === true || pathname?.endsWith("/more-pages") === true || isUserSettingsPage || isSearchSpaceSettingsPage || diff --git a/surfsense_web/components/layout/types/layout.types.ts b/surfsense_web/components/layout/types/layout.types.ts index 1dfb51ca8..1bb0a089e 100644 --- a/surfsense_web/components/layout/types/layout.types.ts +++ b/surfsense_web/components/layout/types/layout.types.ts @@ -74,6 +74,11 @@ export interface ChatsSectionProps { searchSpaceId?: string; } +export interface PageUsageDisplayProps { + pagesUsed: number; + pagesLimit: number; +} + export interface SidebarUserProfileProps { user: User; searchSpaceId?: string; diff --git a/surfsense_web/components/layout/ui/index.ts b/surfsense_web/components/layout/ui/index.ts index 85a47bea1..00b862082 100644 --- a/surfsense_web/components/layout/ui/index.ts +++ b/surfsense_web/components/layout/ui/index.ts @@ -4,10 +4,10 @@ export { IconRail, NavIcon, SearchSpaceAvatar } from "./icon-rail"; export { LayoutShell } from "./shell"; export { ChatListItem, - CreditBalanceDisplay, MobileSidebar, MobileSidebarTrigger, NavSection, + PageUsageDisplay, Sidebar, SidebarCollapseButton, SidebarHeader, diff --git a/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx new file mode 100644 index 000000000..ad31d50bb --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx @@ -0,0 +1,15 @@ +"use client"; + +import { useQuery } from "@rocicorp/zero/react"; +import { useIsAnonymous } from "@/contexts/anonymous-mode"; +import { queries } from "@/zero/queries"; +import { PageUsageDisplay } from "./PageUsageDisplay"; + +export function AuthenticatedPageUsageDisplay() { + const isAnonymous = useIsAnonymous(); + const [me] = useQuery(queries.user.me({})); + + if (isAnonymous || !me) return null; + + return <PageUsageDisplay pagesUsed={me.pagesUsed} pagesLimit={me.pagesLimit} />; +} diff --git a/surfsense_web/components/layout/ui/sidebar/CreditBalanceDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/CreditBalanceDisplay.tsx deleted file mode 100644 index 1d45137fb..000000000 --- a/surfsense_web/components/layout/ui/sidebar/CreditBalanceDisplay.tsx +++ /dev/null @@ -1,55 +0,0 @@ -"use client"; - -import { useQuery } from "@rocicorp/zero/react"; -import { useIsAnonymous } from "@/contexts/anonymous-mode"; -import { cn } from "@/lib/utils"; -import { queries } from "@/zero/queries"; - -// Show the low-balance warning state once the wallet drops below $0.50. -const LOW_BALANCE_WARNING_MICROS = 500_000; - -function formatUsd(micros: number): string { - // Clamp at $0.00 — the balance can dip slightly negative when the actual - // cost of a job exceeds the pre-charge estimate. - const dollars = Math.max(0, micros) / 1_000_000; - if (dollars >= 100) return `$${dollars.toFixed(0)}`; - if (dollars >= 1) return `$${dollars.toFixed(2)}`; - // Sub-dollar balances need extra precision so the user can still tell what - // is left ("$0.042 of credit") instead of rounding to "$0.00". - if (dollars > 0) return `$${dollars.toFixed(3)}`; - return "$0.00"; -} - -/** - * Unified credit-wallet balance shown in the sidebar. - * - * The single ``creditMicrosBalance`` replaces the former page-limit and - * premium-credit meters. Values come from Zero (live-replicated from Postgres) - * as integer micro-USD (1_000_000 == $1.00). A low-balance warning highlights - * the amount when it falls below $0.50 so the user knows to top up or enable - * auto-reload. - */ -export function CreditBalanceDisplay() { - const isAnonymous = useIsAnonymous(); - const [me] = useQuery(queries.user.me({})); - - if (isAnonymous || !me) return null; - - const balanceMicros = me.creditMicrosBalance ?? 0; - const isLow = balanceMicros < LOW_BALANCE_WARNING_MICROS; - - return ( - <div className="flex items-center justify-between text-xs"> - <span className="text-muted-foreground">Credits</span> - <span - className={cn( - "font-medium tabular-nums", - isLow ? "text-amber-600 dark:text-amber-500" : "text-foreground" - )} - title={isLow ? "Low balance — buy credits or enable auto-reload" : undefined} - > - {formatUsd(balanceMicros)} - </span> - </div> - ); -} diff --git a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx index 3785dc649..f757db70e 100644 --- a/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/InboxSidebar.tsx @@ -48,8 +48,8 @@ import { isCommentReplyMetadata, isConnectorIndexingMetadata, isDocumentProcessingMetadata, - isInsufficientCreditsMetadata, isNewMentionMetadata, + isPageLimitExceededMetadata, } from "@/contracts/types/inbox.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import type { InboxItem } from "@/hooks/use-inbox"; @@ -291,7 +291,7 @@ export function InboxSidebarContent({ (item: InboxItem): boolean => { if (activeFilter === "unread") return !item.read; if (activeFilter === "errors") { - if (item.type === "insufficient_credits") return true; + if (item.type === "page_limit_exceeded") return true; const meta = item.metadata as Record<string, unknown> | undefined; return typeof meta?.status === "string" && meta.status === "failed"; } @@ -397,8 +397,8 @@ export function InboxSidebarContent({ router.push(url); } } - } else if (item.type === "insufficient_credits") { - if (isInsufficientCreditsMetadata(item.metadata)) { + } else if (item.type === "page_limit_exceeded") { + if (isPageLimitExceededMetadata(item.metadata)) { const actionUrl = item.metadata.action_url; if (actionUrl) { onOpenChange(false); @@ -470,7 +470,7 @@ export function InboxSidebarContent({ ); } - if (item.type === "insufficient_credits") { + if (item.type === "page_limit_exceeded") { return ( <div className="h-8 w-8 flex items-center justify-center rounded-full bg-amber-500/10"> <AlertTriangle className="h-4 w-4 text-amber-500" /> diff --git a/surfsense_web/components/layout/ui/sidebar/PageUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PageUsageDisplay.tsx new file mode 100644 index 000000000..3d011b762 --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/PageUsageDisplay.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { Progress } from "@/components/ui/progress"; + +interface PageUsageDisplayProps { + pagesUsed: number; + pagesLimit: number; +} + +export function PageUsageDisplay({ pagesUsed, pagesLimit }: PageUsageDisplayProps) { + const usagePercentage = (pagesUsed / pagesLimit) * 100; + + return ( + <div className="space-y-1.5"> + <div className="flex justify-between items-center text-xs"> + <span className="text-muted-foreground"> + {pagesUsed.toLocaleString()} / {pagesLimit.toLocaleString()} pages + </span> + <span className="font-medium">{usagePercentage.toFixed(0)}%</span> + </div> + <Progress value={usagePercentage} className="h-1.5" /> + </div> + ); +} diff --git a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx new file mode 100644 index 000000000..983672d0b --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx @@ -0,0 +1,49 @@ +"use client"; + +import { useQuery } from "@rocicorp/zero/react"; +import { Progress } from "@/components/ui/progress"; +import { useIsAnonymous } from "@/contexts/anonymous-mode"; +import { queries } from "@/zero/queries"; + +/** + * Premium credit balance shown in the sidebar. + * + * Values come from Zero (live-replicated from Postgres) and are stored as + * integer micro-USD (1_000_000 == $1.00). We render in dollars because + * users top up at $1/pack and the credit gets debited at actual provider + * cost. + */ +export function PremiumTokenUsageDisplay() { + const isAnonymous = useIsAnonymous(); + const [me] = useQuery(queries.user.me({})); + + if (isAnonymous || !me) return null; + + const usagePercentage = Math.min( + (me.premiumCreditMicrosUsed / Math.max(me.premiumCreditMicrosLimit, 1)) * 100, + 100 + ); + + const formatUsd = (micros: number) => { + const dollars = micros / 1_000_000; + if (dollars >= 100) return `$${dollars.toFixed(0)}`; + if (dollars >= 1) return `$${dollars.toFixed(2)}`; + // Sub-dollar balances need extra precision so the bar still tells the + // user what's left ("$0.04 of credit") instead of rounding to "$0". + if (dollars > 0) return `$${dollars.toFixed(3)}`; + return "$0"; + }; + + return ( + <div className="space-y-1.5"> + <div className="flex justify-between items-center text-xs"> + <span className="text-muted-foreground"> + {formatUsd(me.premiumCreditMicrosUsed)} / {formatUsd(me.premiumCreditMicrosLimit)} of + credit + </span> + <span className="font-medium">{usagePercentage.toFixed(0)}%</span> + </div> + <Progress value={usagePercentage} className="h-1.5 [&>div]:bg-purple-500" /> + </div> + ); +} diff --git a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx index ee891d78b..6a4785d98 100644 --- a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx @@ -1,6 +1,6 @@ "use client"; -import { CreditCard, SquarePen, Zap } from "lucide-react"; +import { CreditCard, Dot, SquarePen, Zap } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; import { useTranslations } from "next-intl"; @@ -13,9 +13,10 @@ import { useIsAnonymous } from "@/contexts/anonymous-mode"; import { cn } from "@/lib/utils"; import { SIDEBAR_MIN_WIDTH } from "../../hooks/useSidebarResize"; import type { ChatItem, NavItem, PageUsage, SearchSpace, User } from "../../types/layout.types"; +import { AuthenticatedPageUsageDisplay } from "./AuthenticatedPageUsageDisplay"; import { ChatListItem } from "./ChatListItem"; -import { CreditBalanceDisplay } from "./CreditBalanceDisplay"; import { NavSection } from "./NavSection"; +import { PremiumTokenUsageDisplay } from "./PremiumTokenUsageDisplay"; import { SidebarButton } from "./SidebarButton"; import { SidebarCollapseButton } from "./SidebarCollapseButton"; import { SidebarHeader } from "./SidebarHeader"; @@ -403,16 +404,17 @@ function SidebarUsageFooter({ return ( <div className={containerClass}> - <CreditBalanceDisplay /> + <PremiumTokenUsageDisplay /> + <AuthenticatedPageUsageDisplay /> <div className="space-y-0.5"> <Link - href={`/dashboard/${searchSpaceId}/earn-credits`} + href={`/dashboard/${searchSpaceId}/more-pages`} onClick={onNavigate} className="group flex w-full items-center justify-between rounded-md px-1.5 py-1 transition-colors hover:bg-accent" > <span className="flex items-center gap-1.5 text-xs text-muted-foreground group-hover:text-accent-foreground"> <Zap className="h-3 w-3 shrink-0" /> - Earn credits + Get Free Pages </span> <Badge className="h-4 rounded px-1 text-[10px] font-semibold leading-none bg-emerald-600 text-white border-transparent hover:bg-emerald-600"> FREE @@ -425,7 +427,12 @@ function SidebarUsageFooter({ > <span className="flex items-center gap-1.5 text-xs text-muted-foreground group-hover:text-accent-foreground"> <CreditCard className="h-3 w-3 shrink-0" /> - Buy credits + Buy More + </span> + <span className="flex items-center text-[10px] font-medium text-muted-foreground"> + $1/1k + <Dot className="h-3 w-3" /> + $1/1M </span> </Link> </div> diff --git a/surfsense_web/components/layout/ui/sidebar/index.ts b/surfsense_web/components/layout/ui/sidebar/index.ts index fcfe2252d..e25149b06 100644 --- a/surfsense_web/components/layout/ui/sidebar/index.ts +++ b/surfsense_web/components/layout/ui/sidebar/index.ts @@ -1,10 +1,10 @@ export { AllChatsSidebar, AllChatsSidebarContent } from "./AllChatsSidebar"; export { ChatListItem } from "./ChatListItem"; -export { CreditBalanceDisplay } from "./CreditBalanceDisplay"; export { DocumentsSidebar } from "./DocumentsSidebar"; export { InboxSidebar, InboxSidebarContent } from "./InboxSidebar"; export { MobileSidebar, MobileSidebarTrigger } from "./MobileSidebar"; export { NavSection } from "./NavSection"; +export { PageUsageDisplay } from "./PageUsageDisplay"; export { Sidebar } from "./Sidebar"; export { SidebarCollapseButton } from "./SidebarCollapseButton"; export { SidebarHeader } from "./SidebarHeader"; diff --git a/surfsense_web/components/new-chat/chat-example-prompts.tsx b/surfsense_web/components/new-chat/chat-example-prompts.tsx index 61041cc29..98d95b98b 100644 --- a/surfsense_web/components/new-chat/chat-example-prompts.tsx +++ b/surfsense_web/components/new-chat/chat-example-prompts.tsx @@ -2,9 +2,9 @@ import { FilePlus2, - type LucideIcon, Search, Settings2, + type LucideIcon, WandSparkles, Workflow, X, diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 1e11e95d5..46ceee694 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -14,11 +14,11 @@ const demoPlans = [ price: "0", yearlyPrice: "0", period: "", - billingText: "$5 of credit included to start", + billingText: "500 pages + $5 in premium credits included", features: [ "Self Hostable", - "$5 of credit included to start", - "One credit balance for document processing and premium AI features", + "500 pages included to start", + "$5 in premium credits for paid AI models and premium AI features", "Includes access to OpenAI text, audio and image models", "AI automations and agents: scheduled and event-triggered workflows", "Desktop app: Quick, General and Screenshot Assist plus local folder sync", @@ -38,7 +38,7 @@ const demoPlans = [ billingText: "No subscription, buy only when you need more", features: [ "Everything in Free", - "Buy credit in $1 packs — $1 buys $1 of credit, with optional auto-reload", + "Buy 1,000-page packs or $1 in premium credits at $1 each", "Use premium AI models like GPT-5.4, Claude Sonnet 4.6, Gemini 2.5 Pro & 100+ more via OpenRouter", "Connector write-back to Notion, Slack, Linear & Jira", "Priority support on Discord", @@ -84,32 +84,32 @@ interface FAQSection { const faqData: FAQSection[] = [ { - title: "Credits & Document Billing", + title: "Pages & Document Billing", items: [ { - question: "What are credits in SurfSense?", + question: 'What exactly is a "page" in SurfSense?', answer: - "Credits are a single prepaid balance shown in dollars that powers everything in SurfSense — both document processing and premium AI features. New accounts start with $5 of credit. Your balance goes down as you use the product and back up when you top up or earn more, so there's just one number to keep an eye on.", + "A page is a simple billing unit that measures how much content you add to your knowledge base. For PDFs, one page equals one real PDF page. For other document types like Word, PowerPoint, and Excel files, pages are automatically estimated based on the file. Every file uses at least 1 page.", }, { - question: "How much does document processing cost?", + question: "What are Basic and Premium processing modes?", answer: - "Document processing is billed per page out of your credit balance. For PDFs, one page equals one real PDF page; for other document types like Word, PowerPoint, and Excel files, pages are automatically estimated. Basic mode costs $0.001 per page and Premium mode costs $0.01 per page. Premium processing uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables and layouts. Every file uses at least 1 page.", + "When uploading documents, you can choose between two processing modes. Basic mode uses standard extraction and costs 1 page credit per page, great for most documents. Premium processing mode uses advanced extraction optimized for complex financial, medical, and legal documents with intricate tables, layouts, and formatting. It costs 10 page credits per page and does not use your premium AI credits.", }, { question: "How does the Pay As You Go plan work?", answer: - "There's no monthly subscription. When you need more credit, simply buy $1 packs — $1 buys exactly $1 of credit. Purchased credit is added to your balance immediately so you can keep working right away. You only pay when you actually need more, and you can enable auto-reload to top up automatically.", + "There's no monthly subscription. When you need more pages, simply purchase 1,000-page packs at $1 each. Purchased pages are added to your account immediately so you can keep indexing right away. You only pay when you actually need more.", }, { - question: "What happens if I run out of credit?", + question: "What happens if I run out of pages?", answer: - "SurfSense checks your remaining credit before processing each file. If you don't have enough, the upload is paused and you'll be notified so you can buy more credit and continue. For cloud connector syncs, a small overage may be allowed so your sync doesn't partially fail.", + "SurfSense checks your remaining pages before processing each file. If you don't have enough, the upload is paused and you'll be notified. You can purchase additional page packs at any time to continue. For cloud connector syncs, a small overage may be allowed so your sync doesn't partially fail.", }, { - question: "If I delete a document, do I get my credit back?", + question: "If I delete a document, do I get my pages back?", answer: - "No. Deleting a document removes it from your knowledge base, but the credit it used is not refunded. Credit tracks your total usage over time, not how much is currently stored, so be mindful of what you index. Once credit is spent, it's spent even if you later remove the document.", + "No. Deleting a document removes it from your knowledge base, but the pages it used are not refunded. Pages track your total usage over time, not how much is currently stored. So be mindful of what you index. Once pages are spent, they're spent even if you later remove the document.", }, ], }, @@ -117,49 +117,49 @@ const faqData: FAQSection[] = [ title: "File Types & Connectors", items: [ { - question: "Which file types use credit?", + question: "Which file types count toward my page limit?", answer: - "Credit is only used for document files that need processing, including PDFs, Word documents (DOC, DOCX, ODT, RTF), presentations (PPT, PPTX, ODP), spreadsheets (XLS, XLSX, ODS), ebooks (EPUB), and images (JPG, PNG, TIFF, WebP, BMP). Plain text files, code files, Markdown, CSV, TSV, HTML, audio, and video files do not consume any credit.", + "Page limits only apply to document files that need processing, including PDFs, Word documents (DOC, DOCX, ODT, RTF), presentations (PPT, PPTX, ODP), spreadsheets (XLS, XLSX, ODS), ebooks (EPUB), and images (JPG, PNG, TIFF, WebP, BMP). Plain text files, code files, Markdown, CSV, TSV, HTML, audio, and video files do not consume any pages.", }, { - question: "How is credit consumed for documents?", + question: "How are pages consumed?", answer: - "Credit is deducted whenever a document file is successfully indexed into your knowledge base, whether through direct uploads or file-based connector syncs (Google Drive, OneDrive, Dropbox, Local Folder). In Basic mode each page costs $0.001; in Premium mode each page costs $0.01. SurfSense checks your remaining credit before processing and only charges you after the file is indexed. Duplicate documents are automatically detected and won't cost you extra.", + "Pages are deducted whenever a document file is successfully indexed into your knowledge base, whether through direct uploads or file-based connector syncs (Google Drive, OneDrive, Dropbox, Local Folder). In Basic mode, each page costs 1 page credit; in Premium mode, each page costs 10 page credits. SurfSense checks your remaining credits before processing and only charges you after the file is indexed. Duplicate documents are automatically detected and won't cost you extra pages.", }, { - question: "Do connectors like Slack, Notion, or Gmail use credit?", + question: "Do connectors like Slack, Notion, or Gmail use pages?", answer: - "No. Connectors that work with structured text data like Slack, Discord, Notion, Confluence, Jira, Linear, ClickUp, GitHub, Gmail, Google Calendar, Microsoft Teams, Airtable, Elasticsearch, Web Crawler, BookStack, Obsidian, and Luma do not use credit at all. Document-processing charges only apply to file-based connectors such as Google Drive, OneDrive, Dropbox, and Local Folder syncs.", + "No. Connectors that work with structured text data like Slack, Discord, Notion, Confluence, Jira, Linear, ClickUp, GitHub, Gmail, Google Calendar, Microsoft Teams, Airtable, Elasticsearch, Web Crawler, BookStack, Obsidian, and Luma do not use pages at all. Page limits only apply to file-based connectors that need document processing, such as Google Drive, OneDrive, Dropbox, and Local Folder syncs.", }, ], }, { - title: "Premium AI & Credit", + title: "Premium Credits", items: [ { - question: "How is credit used for premium AI?", + question: 'What are "premium credits"?', answer: - "The same credit balance pays for paid AI usage in SurfSense, including premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro, plus premium AI features such as image generation, podcasts, and video presentations when they use paid models. Each request debits the actual USD provider cost, so cheaper and more expensive models bill proportionally.", + "Premium credits are your USD balance for paid AI usage in SurfSense, including premium AI models like GPT-5.4, Claude Sonnet 4.6, and Gemini 2.5 Pro, plus premium AI features such as image generation, podcasts, and video presentations when they use paid models. Each request debits the actual USD provider cost, so cheaper and more expensive models bill proportionally.", }, { - question: "How much credit do I get for free?", + question: "How many premium credits do I get for free?", answer: - "Every registered SurfSense account starts with $5 of credit at no cost. Anonymous users (no login) get 500,000 free tokens across free models before creating an account. Once your included credit runs out, you can top up at any time or earn more by completing tasks.", + "Every registered SurfSense account starts with $5 in premium credits at no cost. Anonymous users (no login) get 500,000 free tokens across free models before creating an account. Once your included premium credits run out, you can top up at any time.", }, { - question: "How does buying credit work?", + question: "How does buying premium credits work?", answer: - "Top-ups are pay as you go, with no subscription. $1 buys $1 of credit, and your balance is spent at provider cost. Purchased credit is added to your account immediately, and you can buy any custom amount. Enable auto-reload to top up automatically when your balance runs low.", + "Premium credit top-ups are pay as you go, with no subscription. $1 buys $1 of credit, and your balance is spent at provider cost. Purchased credit is added to your account immediately. You can buy up to $100 at a time.", }, { - question: "Is there a separate balance for documents and AI?", + question: "Are premium credits the same as page credits?", answer: - "No. SurfSense uses one unified credit balance for everything — document indexing, file-based connector processing, premium model chats, and premium AI generation features all draw from the same wallet. Premium document processing mode simply costs more per page ($0.01 vs $0.001), but it's the same credit.", + "No. Page credits pay for document indexing and file-based connector processing. Premium credits pay for paid AI usage, such as premium model chats and premium AI generation features. Premium document processing mode sounds similar, but it consumes page credits, not premium credits.", }, { - question: "What happens if I run out of credit?", + question: "What happens if I run out of premium credits?", answer: - "When your credit balance runs low, you'll see a warning. Once you run out, paid model requests, premium AI features, and document processing are paused until you top up. You can still use non-premium models and features that do not consume credit.", + "When your premium credit balance runs low, you'll see a warning. Once you run out, paid model requests and premium AI features are paused until you top up. You can still use non-premium models and features that do not consume premium credits.", }, ], }, @@ -174,7 +174,7 @@ const faqData: FAQSection[] = [ { question: "Do automations and agents cost extra?", answer: - "No. There is no separate subscription or add-on fee for automations. Agents draw from the same unified credit balance as the rest of SurfSense. Indexing documents and premium AI model usage during a workflow both consume credit at provider cost. If a workflow only uses free models and indexes no documents, it does not touch your credit.", + "No. There is no separate subscription or add-on fee for automations. Agents use the same page credits and premium credits as the rest of SurfSense. Indexing documents consumes page credits, and premium AI model usage during a workflow consumes premium credits at provider cost. If a workflow only uses free models, it does not touch your premium credits.", }, { question: "How do event-triggered automations work?", @@ -192,9 +192,9 @@ const faqData: FAQSection[] = [ title: "Self-Hosting", items: [ { - question: "Can I self-host SurfSense with unlimited usage?", + question: "Can I self-host SurfSense with unlimited pages and credit?", answer: - "Yes! When self-hosting, you have full control over billing. The default self-hosted setup leaves document-processing credit billing off and gives you effectively unlimited credit, so you can index as much data and use as many AI queries as your infrastructure supports.", + "Yes! When self-hosting, you have full control over your page and premium credit limits. The default self-hosted setup gives you effectively unlimited pages and premium credits, so you can index as much data and use as many AI queries as your infrastructure supports.", }, ], }, @@ -286,8 +286,8 @@ function PricingFAQ() { Frequently Asked Questions </h2> <p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground"> - Everything you need to know about SurfSense credits and billing. Can't find what you - need? Reach out at{" "} + Everything you need to know about SurfSense pages, premium credits, and billing. + Can't find what you need? Reach out at{" "} <a href="mailto:rohan@surfsense.com" className="text-blue-500 underline"> rohan@surfsense.com </a> @@ -372,7 +372,7 @@ function PricingBasic() { <Pricing plans={demoPlans} title="SurfSense Pricing" - description="Start free with $5 of credit. Run AI automations and agents, and pay as you go." + description="Start free with 500 pages & $5 in premium credits. Run AI automations and agents, and pay as you go." /> <PricingFAQ /> </> diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 083cc5e35..d35193cbe 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -17,9 +17,9 @@ import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { GenerateImageToolUI } from "@/components/tool-ui/generate-image"; +import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast"; import { GenerateReportToolUI } from "@/components/tool-ui/generate-report"; import { GenerateResumeToolUI } from "@/components/tool-ui/generate-resume"; -import { GeneratePodcastToolUI } from "@/components/tool-ui/podcast"; const GenerateVideoPresentationToolUI = dynamic( () => diff --git a/surfsense_web/components/settings/auto-reload-settings.tsx b/surfsense_web/components/settings/auto-reload-settings.tsx deleted file mode 100644 index fbb7cbfb9..000000000 --- a/surfsense_web/components/settings/auto-reload-settings.tsx +++ /dev/null @@ -1,276 +0,0 @@ -"use client"; - -import { useQuery as useZeroQuery } from "@rocicorp/zero/react"; -import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; -import { AlertTriangle, CreditCard, RefreshCw } from "lucide-react"; -import { useParams, usePathname, useRouter, useSearchParams } from "next/navigation"; -import { useEffect, useRef, useState } from "react"; -import { toast } from "sonner"; -import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Spinner } from "@/components/ui/spinner"; -import { Switch } from "@/components/ui/switch"; -import { stripeApiService } from "@/lib/apis/stripe-api.service"; -import { AppError } from "@/lib/error"; -import { queries } from "@/zero/queries"; - -const microsToDollars = (micros: number | null | undefined): string => { - if (micros == null) return ""; - return (micros / 1_000_000).toString(); -}; - -const dollarsToMicros = (value: string): number | null => { - const trimmed = value.trim(); - if (trimmed === "") return null; - const dollars = Number(trimmed); - if (!Number.isFinite(dollars) || dollars < 0) return null; - return Math.round(dollars * 1_000_000); -}; - -const formatUsd = (micros: number) => `$${(Math.max(0, micros) / 1_000_000).toFixed(2)}`; - -export function AutoReloadSettings() { - const params = useParams(); - const router = useRouter(); - const pathname = usePathname(); - const searchParams = useSearchParams(); - const queryClient = useQueryClient(); - const searchSpaceId = Number(params?.search_space_id); - - const [enabled, setEnabled] = useState(false); - const [thresholdInput, setThresholdInput] = useState(""); - const [amountInput, setAmountInput] = useState(""); - const seededRef = useRef(false); - - const [me] = useZeroQuery(queries.user.me({})); - const balanceMicros = me?.creditMicrosBalance ?? 0; - - const { data: settings, isLoading } = useQuery({ - queryKey: ["auto-reload-settings"], - queryFn: () => stripeApiService.getAutoReloadSettings(), - }); - - // Seed the form once from the server, then let the user own the inputs. - useEffect(() => { - if (settings && !seededRef.current) { - seededRef.current = true; - setEnabled(settings.enabled); - setThresholdInput(microsToDollars(settings.threshold_micros)); - setAmountInput(microsToDollars(settings.amount_micros)); - } - }, [settings]); - - // Surface the result of the Stripe card-setup redirect. - useEffect(() => { - const setupResult = searchParams.get("auto_reload_setup"); - if (!setupResult) return; - if (setupResult === "success") { - toast.success("Card saved. You can now enable auto-reload."); - queryClient.invalidateQueries({ queryKey: ["auto-reload-settings"] }); - } else if (setupResult === "cancel") { - toast.info("Card setup canceled."); - } - // Strip the query param so refreshes don't re-toast. - router.replace(pathname); - }, [searchParams, router, pathname, queryClient]); - - const setupMutation = useMutation({ - mutationFn: () => - stripeApiService.createAutoReloadSetupSession({ search_space_id: searchSpaceId }), - onSuccess: (response) => { - window.location.assign(response.checkout_url); - }, - onError: () => { - toast.error("Couldn't start card setup. Please try again."); - }, - }); - - const saveMutation = useMutation({ - mutationFn: stripeApiService.updateAutoReloadSettings, - onSuccess: (updated) => { - queryClient.setQueryData(["auto-reload-settings"], updated); - toast.success(updated.enabled ? "Auto-reload is on." : "Auto-reload settings saved."); - }, - onError: (error) => { - if (error instanceof AppError && error.message) { - toast.error(error.message); - return; - } - toast.error("Couldn't save auto-reload settings. Please try again."); - }, - }); - - // Render nothing while loading (avoids a spinner flash on pages where the - // feature flag turns out to be off) and when auto-reload is disabled - // server-side. - if (isLoading || !settings || !settings.feature_enabled) { - return null; - } - - const minAmountDollars = (settings.min_amount_micros / 1_000_000).toFixed(2); - const hasCard = settings.has_payment_method; - - const handleSave = () => { - if (!enabled) { - saveMutation.mutate({ - enabled: false, - threshold_micros: dollarsToMicros(thresholdInput), - amount_micros: dollarsToMicros(amountInput), - }); - return; - } - - const thresholdMicros = dollarsToMicros(thresholdInput); - const amountMicros = dollarsToMicros(amountInput); - - if (!thresholdMicros || thresholdMicros <= 0) { - toast.error("Enter a low-balance threshold greater than $0."); - return; - } - if (amountMicros == null || amountMicros < settings.min_amount_micros) { - toast.error(`Reload amount must be at least $${minAmountDollars}.`); - return; - } - - saveMutation.mutate({ - enabled: true, - threshold_micros: thresholdMicros, - amount_micros: amountMicros, - }); - }; - - return ( - <Card> - <CardHeader> - <CardTitle className="flex items-center gap-2 text-base"> - <RefreshCw className="h-4 w-4 text-amber-500" /> - Auto-reload - </CardTitle> - <CardDescription> - Automatically top up your credit balance when it drops below a threshold, using a saved - card. Current balance:{" "} - <span className="font-medium text-foreground">{formatUsd(balanceMicros)}</span>. - </CardDescription> - </CardHeader> - <CardContent className="space-y-5"> - {settings.failed_at && ( - <Alert variant="destructive"> - <AlertTriangle className="h-4 w-4" /> - <AlertTitle>Last auto-reload failed</AlertTitle> - <AlertDescription> - Your saved card was declined and auto-reload was turned off. Update your card and - re-enable it below to keep topping up automatically. - </AlertDescription> - </Alert> - )} - - {!hasCard ? ( - <div className="flex flex-col items-start gap-3 rounded-lg border bg-muted/20 p-4"> - <div className="flex items-center gap-2 text-sm"> - <CreditCard className="h-4 w-4 text-muted-foreground" /> - <span>Add a card to enable automatic top-ups.</span> - </div> - <Button onClick={() => setupMutation.mutate()} disabled={setupMutation.isPending}> - {setupMutation.isPending ? ( - <> - <Spinner size="xs" /> - Redirecting - </> - ) : ( - "Add a card" - )} - </Button> - </div> - ) : ( - <> - <div className="flex items-center justify-between gap-4"> - <div className="space-y-0.5"> - <Label htmlFor="auto-reload-toggle" className="text-sm font-medium"> - Enable auto-reload - </Label> - <p className="text-xs text-muted-foreground"> - Charge your saved card when the balance gets low. - </p> - </div> - <Switch id="auto-reload-toggle" checked={enabled} onCheckedChange={setEnabled} /> - </div> - - <div className="grid gap-4 sm:grid-cols-2"> - <div className="space-y-1.5"> - <Label htmlFor="auto-reload-threshold" className="text-xs"> - When balance falls below - </Label> - <div className="relative"> - <span className="pointer-events-none absolute left-3 top-1/2 -translate-y-1/2 text-sm text-muted-foreground"> - $ - </span> - <Input - id="auto-reload-threshold" - type="number" - min="0" - step="1" - inputMode="decimal" - className="pl-6 tabular-nums" - value={thresholdInput} - onChange={(e) => setThresholdInput(e.target.value)} - disabled={!enabled} - placeholder="5" - /> - </div> - </div> - <div className="space-y-1.5"> - <Label htmlFor="auto-reload-amount" className="text-xs"> - Add this much credit - </Label> - <div className="relative"> - <span className="pointer-events-none absolute left-3 top-1/2 -translate-y-1/2 text-sm text-muted-foreground"> - $ - </span> - <Input - id="auto-reload-amount" - type="number" - min={minAmountDollars} - step="1" - inputMode="decimal" - className="pl-6 tabular-nums" - value={amountInput} - onChange={(e) => setAmountInput(e.target.value)} - disabled={!enabled} - placeholder="10" - /> - </div> - <p className="text-[11px] text-muted-foreground">Minimum ${minAmountDollars}.</p> - </div> - </div> - - <div className="flex items-center justify-between gap-3"> - <Button - variant="ghost" - size="sm" - className="text-muted-foreground" - onClick={() => setupMutation.mutate()} - disabled={setupMutation.isPending} - > - <CreditCard className="h-3.5 w-3.5" /> - Update card - </Button> - <Button onClick={handleSave} disabled={saveMutation.isPending}> - {saveMutation.isPending ? ( - <> - <Spinner size="xs" /> - Saving - </> - ) : ( - "Save" - )} - </Button> - </div> - </> - )} - </CardContent> - </Card> - ); -} diff --git a/surfsense_web/components/settings/buy-pages-content.tsx b/surfsense_web/components/settings/buy-pages-content.tsx new file mode 100644 index 000000000..82b8d8e2a --- /dev/null +++ b/surfsense_web/components/settings/buy-pages-content.tsx @@ -0,0 +1,148 @@ +"use client"; + +import { useMutation, useQuery } from "@tanstack/react-query"; +import { Minus, Plus } from "lucide-react"; +import { useParams } from "next/navigation"; +import { useState } from "react"; +import { toast } from "sonner"; +import { Button } from "@/components/ui/button"; +import { Spinner } from "@/components/ui/spinner"; +import { stripeApiService } from "@/lib/apis/stripe-api.service"; +import { AppError } from "@/lib/error"; +import { cn } from "@/lib/utils"; + +const PAGE_PACK_SIZE = 1000; +const PRICE_PER_PACK_USD = 1; +const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const; + +export function BuyPagesContent() { + const params = useParams(); + const [quantity, setQuantity] = useState(1); + const { data: stripeStatus } = useQuery({ + queryKey: ["stripe-status"], + queryFn: () => stripeApiService.getStatus(), + }); + + const purchaseMutation = useMutation({ + mutationFn: stripeApiService.createCheckoutSession, + onSuccess: (response) => { + window.location.assign(response.checkout_url); + }, + onError: (error) => { + if (error instanceof AppError && error.message) { + toast.error(error.message); + return; + } + toast.error("Failed to start checkout. Please try again."); + }, + }); + + const searchSpaceId = Number(params.search_space_id); + const hasValidSearchSpace = Number.isFinite(searchSpaceId) && searchSpaceId > 0; + const totalPages = quantity * PAGE_PACK_SIZE; + const totalPrice = quantity * PRICE_PER_PACK_USD; + + if (stripeStatus && !stripeStatus.page_buying_enabled) { + return ( + <div className="w-full space-y-3 text-center"> + <h2 className="text-xl font-bold tracking-tight">Buy Pages</h2> + <p className="text-sm text-muted-foreground">Page purchases are temporarily unavailable.</p> + </div> + ); + } + + const handleBuyNow = () => { + if (!hasValidSearchSpace) { + toast.error("Unable to determine the current workspace for checkout."); + return; + } + purchaseMutation.mutate({ + quantity, + search_space_id: searchSpaceId, + }); + }; + + return ( + <div className="w-full space-y-5"> + <div className="text-center"> + <h2 className="text-xl font-bold tracking-tight">Buy Pages</h2> + <p className="mt-1 text-sm text-muted-foreground">$1 per 1,000 pages, pay as you go</p> + </div> + + <div className="space-y-3"> + {/* Stepper */} + <div className="flex items-center justify-center gap-3"> + <Button + type="button" + variant="ghost" + size="icon" + onClick={() => setQuantity((q) => Math.max(1, q - 1))} + disabled={quantity <= 1 || purchaseMutation.isPending} + className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" + > + <Minus className="h-3.5 w-3.5" /> + </Button> + <span className="min-w-28 text-center text-lg font-semibold tabular-nums"> + {totalPages.toLocaleString()} + </span> + <Button + type="button" + variant="ghost" + size="icon" + onClick={() => setQuantity((q) => Math.min(100, q + 1))} + disabled={quantity >= 100 || purchaseMutation.isPending} + className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" + > + <Plus className="h-3.5 w-3.5" /> + </Button> + </div> + + {/* Quick-pick presets */} + <div className="flex flex-wrap justify-center gap-1.5"> + {PRESET_MULTIPLIERS.map((m) => ( + <Button + key={m} + type="button" + variant="ghost" + onClick={() => setQuantity(m)} + disabled={purchaseMutation.isPending} + className={cn( + "h-auto rounded-md px-2.5 py-1 text-xs font-medium tabular-nums transition-colors disabled:opacity-60", + quantity === m + ? "bg-accent text-accent-foreground" + : "text-muted-foreground hover:bg-accent hover:text-accent-foreground" + )} + > + {(m * PAGE_PACK_SIZE).toLocaleString()} + </Button> + ))} + </div> + + <div className="flex items-center justify-between rounded-lg border bg-muted/30 px-3 py-2"> + <span className="text-sm font-medium tabular-nums"> + {totalPages.toLocaleString()} pages + </span> + <span className="text-sm font-semibold tabular-nums">${totalPrice}</span> + </div> + + <Button + className="w-full" + disabled={purchaseMutation.isPending || !hasValidSearchSpace} + onClick={handleBuyNow} + > + {purchaseMutation.isPending ? ( + <> + <Spinner size="xs" /> + Redirecting + </> + ) : ( + <> + Buy {totalPages.toLocaleString()} Pages for ${totalPrice} + </> + )} + </Button> + <p className="text-center text-[11px] text-muted-foreground">Secure checkout via Stripe</p> + </div> + </div> + ); +} diff --git a/surfsense_web/components/settings/buy-credits-content.tsx b/surfsense_web/components/settings/buy-tokens-content.tsx similarity index 57% rename from surfsense_web/components/settings/buy-credits-content.tsx rename to surfsense_web/components/settings/buy-tokens-content.tsx index 8cb339420..4b0605f28 100644 --- a/surfsense_web/components/settings/buy-credits-content.tsx +++ b/surfsense_web/components/settings/buy-tokens-content.tsx @@ -7,59 +7,46 @@ import { useParams } from "next/navigation"; import { useState } from "react"; import { toast } from "sonner"; import { Button } from "@/components/ui/button"; +import { Progress } from "@/components/ui/progress"; import { Spinner } from "@/components/ui/spinner"; import { stripeApiService } from "@/lib/apis/stripe-api.service"; import { AppError } from "@/lib/error"; import { cn } from "@/lib/utils"; import { queries } from "@/zero/queries"; -// One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the backend. -// ETL page processing and premium turns are both debited from the same wallet -// at the actual cost, so $1 of credit always buys $1 of usage at cost. +// One pack = $1.00 of credit, stored as 1_000_000 micro-USD on the +// backend. Premium turns are debited at the actual provider cost +// reported by LiteLLM, so $1 of credit always buys $1 of provider +// usage at cost. const CREDIT_PER_PACK_MICROS = 1_000_000; const PRICE_PER_PACK_USD = 1; -const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50, 100] as const; -const MIN_QUANTITY = 1; -const MAX_QUANTITY = 10_000; +const PRESET_MULTIPLIERS = [1, 2, 5, 10, 25, 50] as const; -const clampQuantity = (value: number) => - Math.min(MAX_QUANTITY, Math.max(MIN_QUANTITY, Math.floor(value))); - -const formatUsd = (micros: number) => { - // Clamp at $0.00 — the balance can dip slightly negative when actual cost - // exceeds the pre-charge estimate. - const dollars = Math.max(0, micros) / 1_000_000; +const formatUsd = (micros: number, options?: { compact?: boolean }) => { + const dollars = micros / 1_000_000; + if (options?.compact && dollars >= 1) return `$${dollars.toFixed(2)}`; if (dollars >= 100) return `$${dollars.toFixed(0)}`; if (dollars >= 1) return `$${dollars.toFixed(2)}`; if (dollars > 0) return `$${dollars.toFixed(3)}`; - return "$0.00"; + return "$0"; }; -export function BuyCreditsContent() { +export function BuyTokensContent() { const params = useParams(); const searchSpaceId = Number(params?.search_space_id); const [quantity, setQuantity] = useState(1); - // Raw text of the amount field so the user can clear it while typing; - // committed back to a clamped integer on blur. - const [amountInput, setAmountInput] = useState("1"); - - const commitQuantity = (value: number) => { - const clamped = clampQuantity(Number.isFinite(value) ? value : MIN_QUANTITY); - setQuantity(clamped); - setAmountInput(String(clamped)); - }; // Server config flag: stays on REST, not per-user. - const { data: creditStatus } = useQuery({ - queryKey: ["credit-status"], - queryFn: () => stripeApiService.getCreditStatus(), + const { data: tokenStatus } = useQuery({ + queryKey: ["token-status"], + queryFn: () => stripeApiService.getTokenStatus(), }); // Live per-user balance via Zero. const [me] = useZeroQuery(queries.user.me({})); const purchaseMutation = useMutation({ - mutationFn: stripeApiService.createCreditCheckoutSession, + mutationFn: stripeApiService.createTokenCheckoutSession, onSuccess: (response) => { window.location.assign(response.checkout_url); }, @@ -75,10 +62,10 @@ export function BuyCreditsContent() { const totalCreditMicros = quantity * CREDIT_PER_PACK_MICROS; const totalPrice = quantity * PRICE_PER_PACK_USD; - if (creditStatus && !creditStatus.credit_buying_enabled) { + if (tokenStatus && !tokenStatus.token_buying_enabled) { return ( <div className="w-full space-y-3 text-center"> - <h2 className="text-xl font-bold tracking-tight">Buy Credits</h2> + <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2> <p className="text-sm text-muted-foreground"> Credit purchases are temporarily unavailable. </p> @@ -86,20 +73,35 @@ export function BuyCreditsContent() { ); } - const balanceMicros = me?.creditMicrosBalance ?? creditStatus?.credit_micros_balance ?? 0; + const used = me?.premiumCreditMicrosUsed ?? 0; + const limit = me?.premiumCreditMicrosLimit ?? 0; + // Mirrors the backend formula in stripe_routes.py (max(0, limit - used)). + const remaining = Math.max(0, limit - used); + const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0; return ( <div className="w-full space-y-5"> <div className="text-center"> - <h2 className="text-xl font-bold tracking-tight">Buy Credits</h2> + <h2 className="text-xl font-bold tracking-tight">Buy Premium Credit</h2> + <p className="mt-1 text-sm text-muted-foreground"> + $1 buys $1 of credit, billed at provider cost + </p> </div> - <div className="rounded-lg border bg-muted/20 p-3"> - <div className="flex items-center justify-between text-sm"> - <span className="text-muted-foreground">Current balance</span> - <span className="font-semibold tabular-nums">{formatUsd(balanceMicros)}</span> + {me && ( + <div className="rounded-lg border bg-muted/20 p-3 space-y-1.5"> + <div className="flex justify-between items-center text-xs"> + <span className="text-muted-foreground"> + {formatUsd(used)} / {formatUsd(limit)} of credit + </span> + <span className="font-medium">{usagePercentage.toFixed(0)}%</span> + </div> + <Progress value={usagePercentage} className="h-1.5 [&>div]:bg-purple-500" /> + <p className="text-[11px] text-muted-foreground"> + {formatUsd(remaining)} of credit remaining + </p> </div> - </div> + )} <div className="space-y-3"> <div className="flex items-center justify-center gap-3"> @@ -107,39 +109,21 @@ export function BuyCreditsContent() { type="button" variant="ghost" size="icon" - onClick={() => commitQuantity(quantity - 1)} - disabled={quantity <= MIN_QUANTITY || purchaseMutation.isPending} + onClick={() => setQuantity((q) => Math.max(1, q - 1))} + disabled={quantity <= 1 || purchaseMutation.isPending} className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" > <Minus className="h-3.5 w-3.5" /> </Button> - <div className="flex items-baseline gap-1.5"> - <span className="text-lg font-semibold">$</span> - <input - type="text" - inputMode="numeric" - value={amountInput} - onChange={(e) => { - const raw = e.target.value.replace(/[^0-9]/g, ""); - setAmountInput(raw); - const parsed = Number.parseInt(raw, 10); - if (Number.isFinite(parsed)) { - setQuantity(clampQuantity(parsed)); - } - }} - onBlur={() => commitQuantity(Number.parseInt(amountInput, 10))} - disabled={purchaseMutation.isPending} - aria-label="Credit amount in US dollars" - className="w-20 rounded-md border bg-transparent px-2 py-1 text-center text-lg font-semibold tabular-nums outline-none focus:ring-2 focus:ring-ring disabled:opacity-60" - /> - <span className="text-sm text-muted-foreground">of credit</span> - </div> + <span className="min-w-32 text-center text-lg font-semibold tabular-nums"> + ${(totalCreditMicros / 1_000_000).toFixed(0)} of credit + </span> <Button type="button" variant="ghost" size="icon" - onClick={() => commitQuantity(quantity + 1)} - disabled={quantity >= MAX_QUANTITY || purchaseMutation.isPending} + onClick={() => setQuantity((q) => Math.min(100, q + 1))} + disabled={quantity >= 100 || purchaseMutation.isPending} className="size-8 text-muted-foreground shadow-none transition-colors hover:bg-muted hover:text-white disabled:opacity-40" > <Plus className="h-3.5 w-3.5" /> @@ -152,7 +136,7 @@ export function BuyCreditsContent() { key={m} type="button" variant="ghost" - onClick={() => commitQuantity(m)} + onClick={() => setQuantity(m)} disabled={purchaseMutation.isPending} className={cn( "h-auto rounded-md px-2.5 py-1 text-xs font-medium tabular-nums transition-colors disabled:opacity-60", diff --git a/surfsense_web/components/settings/earn-credits-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx similarity index 80% rename from surfsense_web/components/settings/earn-credits-content.tsx rename to surfsense_web/components/settings/more-pages-content.tsx index 731ea7726..e1b05f4d2 100644 --- a/surfsense_web/components/settings/earn-credits-content.tsx +++ b/surfsense_web/components/settings/more-pages-content.tsx @@ -22,14 +22,7 @@ import { } from "@/lib/posthog/events"; import { cn } from "@/lib/utils"; -// Compact dollar label for a task's reward (e.g. "+$0.03"). -const formatRewardUsd = (micros: number) => { - const dollars = micros / 1_000_000; - if (dollars >= 1) return `+$${dollars.toFixed(2)}`; - return `+$${dollars.toFixed(2)}`; -}; - -export function EarnCreditsContent() { +export function MorePagesContent() { const params = useParams(); const queryClient = useQueryClient(); const searchSpaceId = params?.search_space_id ?? ""; @@ -42,11 +35,11 @@ export function EarnCreditsContent() { queryKey: ["incentive-tasks"], queryFn: () => incentiveTasksApiService.getTasks(), }); - const { data: creditStatus } = useQuery({ - queryKey: ["credit-status"], - queryFn: () => stripeApiService.getCreditStatus(), + const { data: stripeStatus } = useQuery({ + queryKey: ["stripe-status"], + queryFn: () => stripeApiService.getStatus(), }); - const creditBuyingEnabled = creditStatus?.credit_buying_enabled ?? true; + const pageBuyingEnabled = stripeStatus?.page_buying_enabled ?? true; const completeMutation = useMutation({ mutationFn: incentiveTasksApiService.completeTask, @@ -55,7 +48,7 @@ export function EarnCreditsContent() { toast.success(response.message); const task = data?.tasks.find((t) => t.task_type === taskType); if (task) { - trackIncentiveTaskCompleted(taskType, task.credit_micros_reward); + trackIncentiveTaskCompleted(taskType, task.pages_reward); } queryClient.invalidateQueries({ queryKey: ["incentive-tasks"] }); queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY }); @@ -76,12 +69,12 @@ export function EarnCreditsContent() { return ( <div className="w-full space-y-5"> <div className="text-center"> - <h2 className="text-xl font-bold tracking-tight">Earn Credits</h2> - <p className="mt-1 text-sm text-muted-foreground">Earn bonus credits by completing tasks</p> + <h2 className="text-xl font-bold tracking-tight">Get Free Pages</h2> + <p className="mt-1 text-sm text-muted-foreground">Earn bonus pages by completing tasks</p> </div> <div className="space-y-2"> - <h3 className="text-sm font-semibold">Earn Bonus Credits</h3> + <h3 className="text-sm font-semibold">Earn Bonus Pages</h3> {isLoading ? ( <div className="space-y-1.5"> {["github", "reddit", "discord"].map((task) => ( @@ -104,16 +97,14 @@ export function EarnCreditsContent() { <CardContent className="flex items-center gap-3 p-3"> <div className={cn( - "flex h-9 min-w-9 shrink-0 items-center justify-center rounded-full px-2", + "flex h-8 w-8 shrink-0 items-center justify-center rounded-full", task.completed ? "bg-primary text-primary-foreground" : "bg-muted" )} > {task.completed ? ( <Check className="h-3.5 w-3.5" /> ) : ( - <span className="text-[11px] font-semibold tabular-nums"> - {formatRewardUsd(task.credit_micros_reward)} - </span> + <span className="text-xs font-semibold">+{task.pages_reward}</span> )} </div> <p @@ -160,13 +151,15 @@ export function EarnCreditsContent() { <div className="text-center"> <p className="text-sm text-muted-foreground">Need more?</p> - {creditBuyingEnabled ? ( + {pageBuyingEnabled ? ( <Button asChild variant="link" className="text-emerald-600 dark:text-emerald-400"> - <Link href={`/dashboard/${searchSpaceId}/buy-more`}>Buy credits at $1 per $1</Link> + <Link href={`/dashboard/${searchSpaceId}/buy-pages`}> + Buy page packs at $1 per 1,000 + </Link> </Button> ) : ( <p className="text-xs text-muted-foreground"> - Credit purchases are temporarily unavailable. + Page purchases are temporarily unavailable. </p> )} </div> diff --git a/surfsense_web/components/tool-ui/audio.tsx b/surfsense_web/components/tool-ui/audio.tsx index cf78298b5..aeadae45b 100644 --- a/surfsense_web/components/tool-ui/audio.tsx +++ b/surfsense_web/components/tool-ui/audio.tsx @@ -201,7 +201,7 @@ export function Audio({ id, src, title, durationMs, className }: AudioProps) { <div className="mx-5 h-px bg-border/50" /> <div className="px-5 pt-3 pb-4 space-y-3"> - <div className="space-y-2"> + <div className="space-y-0.5"> <Slider value={[currentTime]} max={duration || 100} diff --git a/surfsense_web/components/tool-ui/generate-podcast.tsx b/surfsense_web/components/tool-ui/generate-podcast.tsx new file mode 100644 index 000000000..2a62785e8 --- /dev/null +++ b/surfsense_web/components/tool-ui/generate-podcast.tsx @@ -0,0 +1,468 @@ +"use client"; + +import type { ToolCallMessagePartProps } from "@assistant-ui/react"; +import { useParams, usePathname } from "next/navigation"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { z } from "zod"; +import { TextShimmerLoader } from "@/components/prompt-kit/loader"; +import { Audio } from "@/components/tool-ui/audio"; +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/ui/accordion"; +import { baseApiService } from "@/lib/apis/base-api.service"; +import { authenticatedFetch } from "@/lib/auth-utils"; +import { clearActivePodcastTaskId, setActivePodcastTaskId } from "@/lib/chat/podcast-state"; +import { BACKEND_URL } from "@/lib/env-config"; + +/** + * Zod schemas for runtime validation + */ +const GeneratePodcastArgsSchema = z.object({ + source_content: z.string(), + podcast_title: z.string().nullish(), + user_prompt: z.string().nullish(), +}); + +const GeneratePodcastResultSchema = z.object({ + // Support both old and new status values for backwards compatibility + status: z.enum([ + "pending", + "generating", + "ready", + "failed", + // Legacy values from old saved chats + "processing", + "already_generating", + "success", + "error", + ]), + podcast_id: z.number().nullish(), + task_id: z.string().nullish(), // Legacy field for old saved chats + title: z.string().nullish(), + transcript_entries: z.number().nullish(), + message: z.string().nullish(), + error: z.string().nullish(), +}); + +const PodcastStatusResponseSchema = z.object({ + status: z.enum(["pending", "generating", "ready", "failed"]), + id: z.number(), + title: z.string(), + transcript_entries: z.number().nullish(), + error: z.string().nullish(), +}); + +const PodcastTranscriptEntrySchema = z.object({ + speaker_id: z.number(), + dialog: z.string(), +}); + +const PodcastDetailsSchema = z.object({ + podcast_transcript: z.array(PodcastTranscriptEntrySchema).nullish(), +}); + +/** + * Types derived from Zod schemas + */ +type GeneratePodcastArgs = z.infer<typeof GeneratePodcastArgsSchema>; +type GeneratePodcastResult = z.infer<typeof GeneratePodcastResultSchema>; +type PodcastStatusResponse = z.infer<typeof PodcastStatusResponseSchema>; +type PodcastTranscriptEntry = z.infer<typeof PodcastTranscriptEntrySchema>; + +/** + * Parse and validate podcast status response + */ +function parsePodcastStatusResponse(data: unknown): PodcastStatusResponse | null { + const result = PodcastStatusResponseSchema.safeParse(data); + if (!result.success) { + console.warn("Invalid podcast status response:", result.error.issues); + return null; + } + return result.data; +} + +/** + * Parse and validate podcast details + */ +function parsePodcastDetails(data: unknown): { podcast_transcript?: PodcastTranscriptEntry[] } { + const result = PodcastDetailsSchema.safeParse(data); + if (!result.success) { + console.warn("Invalid podcast details:", result.error.issues); + return {}; + } + return { + podcast_transcript: result.data.podcast_transcript ?? undefined, + }; +} + +function PodcastGeneratingState({ title }: { title: string }) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> + <TextShimmerLoader text="Generating podcast" size="sm" /> + </div> + </div> + ); +} + +function PodcastErrorState({ title, error }: { title: string; error: string }) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-destructive">Podcast Generation Failed</p> + </div> + <div className="mx-5 h-px bg-border/50" /> + <div className="px-5 py-4"> + <p className="text-sm font-medium text-foreground line-clamp-2">{title}</p> + <p className="text-sm text-muted-foreground mt-1">{error}</p> + </div> + </div> + ); +} + +function AudioLoadingState({ title }: { title: string }) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> + <TextShimmerLoader text="Loading audio" size="sm" /> + </div> + </div> + ); +} + +function PodcastPlayer({ + podcastId, + title, + durationMs, +}: { + podcastId: number; + title: string; + durationMs?: number; +}) { + const params = useParams(); + const pathname = usePathname(); + const isPublicRoute = pathname?.startsWith("/public/"); + const shareToken = isPublicRoute && typeof params?.token === "string" ? params.token : null; + + const [audioSrc, setAudioSrc] = useState<string | null>(null); + const [transcript, setTranscript] = useState<PodcastTranscriptEntry[] | null>(null); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState<string | null>(null); + const objectUrlRef = useRef<string | null>(null); + + // Cleanup object URL on unmount + useEffect(() => { + return () => { + if (objectUrlRef.current) { + URL.revokeObjectURL(objectUrlRef.current); + } + }; + }, []); + + // Fetch audio and podcast details (including transcript) + const loadPodcast = useCallback(async () => { + setIsLoading(true); + setError(null); + + try { + // Revoke previous object URL if exists + if (objectUrlRef.current) { + URL.revokeObjectURL(objectUrlRef.current); + objectUrlRef.current = null; + } + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 60000); // 60s timeout + + try { + let audioBlob: Blob; + let rawPodcastDetails: unknown = null; + + if (shareToken) { + // Public view - use public endpoints (baseApiService handles no-auth for /api/v1/public/) + const [blob, details] = await Promise.all([ + baseApiService.getBlob(`/api/v1/public/${shareToken}/podcasts/${podcastId}/stream`), + baseApiService.get(`/api/v1/public/${shareToken}/podcasts/${podcastId}`), + ]); + audioBlob = blob; + rawPodcastDetails = details; + } else { + // Authenticated view - fetch audio and details in parallel + const [audioResponse, details] = await Promise.all([ + authenticatedFetch(`${BACKEND_URL}/api/v1/podcasts/${podcastId}/audio`, { + method: "GET", + signal: controller.signal, + }), + baseApiService.get<unknown>(`/api/v1/podcasts/${podcastId}`), + ]); + + if (!audioResponse.ok) { + throw new Error(`Failed to load audio: ${audioResponse.status}`); + } + + audioBlob = await audioResponse.blob(); + rawPodcastDetails = details; + } + + // Create object URL from blob + const objectUrl = URL.createObjectURL(audioBlob); + objectUrlRef.current = objectUrl; + setAudioSrc(objectUrl); + + // Parse and validate podcast details, then set transcript + if (rawPodcastDetails) { + const podcastDetails = parsePodcastDetails(rawPodcastDetails); + if (podcastDetails.podcast_transcript) { + setTranscript(podcastDetails.podcast_transcript); + } + } + } finally { + clearTimeout(timeoutId); + } + } catch (err) { + console.error("Error loading podcast:", err); + if (err instanceof DOMException && err.name === "AbortError") { + setError("Request timed out. Please try again."); + } else { + setError(err instanceof Error ? err.message : "Failed to load podcast"); + } + } finally { + setIsLoading(false); + } + }, [podcastId, shareToken]); + + // Load podcast when component mounts + useEffect(() => { + loadPodcast(); + }, [loadPodcast]); + + if (isLoading) { + return <AudioLoadingState title={title} />; + } + + if (error || !audioSrc) { + return <PodcastErrorState title={title} error={error || "Failed to load audio"} />; + } + + const hasTranscript = transcript && transcript.length > 0; + + return ( + <div className="my-4"> + <Audio + id={`podcast-${podcastId}`} + src={audioSrc} + title={title} + durationMs={durationMs} + className={hasTranscript ? "rounded-b-none border-b-0" : undefined} + /> + {hasTranscript && ( + <div className="max-w-lg overflow-hidden rounded-b-2xl border border-t-0 bg-muted/30 select-none"> + <div className="mx-5 h-px bg-border/50" /> + <Accordion type="single" collapsible className="px-5"> + <AccordionItem value="transcript" className="border-b-0"> + <AccordionTrigger className="py-3 text-xs sm:text-sm font-medium text-muted-foreground hover:text-accent-foreground hover:no-underline"> + View transcript + </AccordionTrigger> + <AccordionContent className="pb-0"> + <div className="space-y-2 max-h-64 sm:max-h-96 overflow-y-auto select-text"> + {transcript.map((entry, idx) => ( + <div key={`${idx}-${entry.speaker_id}`} className="text-xs sm:text-sm"> + <span className="font-medium text-primary"> + Speaker {entry.speaker_id + 1}: + </span>{" "} + <span className="text-muted-foreground">{entry.dialog}</span> + </div> + ))} + </div> + </AccordionContent> + </AccordionItem> + </Accordion> + </div> + )} + </div> + ); +} + +/** + * Polling component that checks podcast status and shows player when ready + */ +function PodcastStatusPoller({ podcastId, title }: { podcastId: number; title: string }) { + const [podcastStatus, setPodcastStatus] = useState<PodcastStatusResponse | null>(null); + const pollingRef = useRef<NodeJS.Timeout | null>(null); + + // Set active podcast state when this component mounts + useEffect(() => { + setActivePodcastTaskId(String(podcastId)); + + // Clear when component unmounts + return () => { + clearActivePodcastTaskId(); + }; + }, [podcastId]); + + // Poll for podcast status + useEffect(() => { + const pollStatus = async () => { + try { + const rawResponse = await baseApiService.get<unknown>(`/api/v1/podcasts/${podcastId}`); + const response = parsePodcastStatusResponse(rawResponse); + if (response) { + setPodcastStatus(response); + + // Stop polling if podcast is ready or failed + if (response.status === "ready" || response.status === "failed") { + if (pollingRef.current) { + clearInterval(pollingRef.current); + pollingRef.current = null; + } + clearActivePodcastTaskId(); + } + } + } catch (err) { + console.error("Error polling podcast status:", err); + // Don't stop polling on network errors, continue polling + } + }; + + // Initial poll + pollStatus(); + + // Poll every 5 seconds + pollingRef.current = setInterval(pollStatus, 5000); + + return () => { + if (pollingRef.current) { + clearInterval(pollingRef.current); + } + }; + }, [podcastId]); + + // Show loading state while pending or generating + if ( + !podcastStatus || + podcastStatus.status === "pending" || + podcastStatus.status === "generating" + ) { + return <PodcastGeneratingState title={title} />; + } + + // Show error state + if (podcastStatus.status === "failed") { + return <PodcastErrorState title={title} error={podcastStatus.error || "Generation failed"} />; + } + + // Show player when ready + if (podcastStatus.status === "ready") { + return <PodcastPlayer podcastId={podcastStatus.id} title={podcastStatus.title || title} />; + } + + // Fallback + return <PodcastErrorState title={title} error="Unexpected state" />; +} + +/** + * Generate Podcast Tool UI Component + * + * This component is registered with assistant-ui to render custom UI + * when the generate_podcast tool is called by the agent. + * + * It polls for task completion and auto-updates when the podcast is ready. + */ +export const GeneratePodcastToolUI = ({ + args, + result, + status, +}: ToolCallMessagePartProps<GeneratePodcastArgs, GeneratePodcastResult>) => { + const title = args.podcast_title || "SurfSense Podcast"; + + // Loading state - tool is still running (agent processing) + if (status.type === "running" || status.type === "requires-action") { + return <PodcastGeneratingState title={title} />; + } + + // Incomplete/cancelled state + if (status.type === "incomplete") { + if (status.reason === "cancelled") { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-muted-foreground">Podcast Cancelled</p> + <p className="text-xs text-muted-foreground mt-0.5">Podcast generation was cancelled</p> + </div> + </div> + ); + } + if (status.reason === "error") { + return ( + <PodcastErrorState + title={title} + error={typeof status.error === "string" ? status.error : "An error occurred"} + /> + ); + } + } + + // No result yet + if (!result) { + return <PodcastGeneratingState title={title} />; + } + + // Failed result (new: "failed", legacy: "error") + if (result.status === "failed" || result.status === "error") { + return <PodcastErrorState title={title} error={result.error || "Generation failed"} />; + } + + // Pending/generating rows have a stable podcast_id, so the card can poll + // independently while the chat stream finishes. + if ( + (result.status === "pending" || + result.status === "generating" || + result.status === "processing") && + result.podcast_id + ) { + return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />; + } + + // Legacy duplicate/no-ID result - show a simple warning, don't create + // another poller. The first tool call will display the podcast when ready. + if (result.status === "generating" || result.status === "already_generating") { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-foreground">Podcast already in progress</p> + <p className="text-xs text-muted-foreground mt-0.5"> + Please wait for the current podcast to complete. + </p> + </div> + </div> + ); + } + + // Ready with podcast_id (new: "ready", legacy: "success") + if ((result.status === "ready" || result.status === "success") && result.podcast_id) { + return <PodcastPlayer podcastId={result.podcast_id} title={result.title || title} />; + } + + // Legacy: old chats with Celery task_id (status: "processing" or "success" without podcast_id) + // These can't be recovered since the old task polling endpoint no longer exists + if (result.task_id && !result.podcast_id) { + return ( + <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> + <div className="px-5 pt-5 pb-4"> + <p className="text-sm font-semibold text-muted-foreground">Podcast Unavailable</p> + <p className="text-xs text-muted-foreground mt-0.5"> + This podcast was generated with an older version. Please generate a new one. + </p> + </div> + </div> + ); + } + + // Fallback - missing required data + return <PodcastErrorState title={title} error="Missing podcast ID" />; +}; diff --git a/surfsense_web/components/tool-ui/index.ts b/surfsense_web/components/tool-ui/index.ts index a6576f065..ee5072dad 100644 --- a/surfsense_web/components/tool-ui/index.ts +++ b/surfsense_web/components/tool-ui/index.ts @@ -16,6 +16,7 @@ export { GenerateImageResultSchema, GenerateImageToolUI, } from "./generate-image"; +export { GeneratePodcastToolUI } from "./generate-podcast"; export { GenerateReportToolUI } from "./generate-report"; export { CreateGoogleDriveFileToolUI, DeleteGoogleDriveFileToolUI } from "./google-drive"; export { @@ -43,7 +44,6 @@ export { type SerializablePlan, type TodoStatus, } from "./plan"; -export { GeneratePodcastToolUI } from "./podcast"; export { type ExecuteArgs, ExecuteArgsSchema, diff --git a/surfsense_web/components/tool-ui/podcast/brief-review.tsx b/surfsense_web/components/tool-ui/podcast/brief-review.tsx deleted file mode 100644 index 3473b64d6..000000000 --- a/surfsense_web/components/tool-ui/podcast/brief-review.tsx +++ /dev/null @@ -1,399 +0,0 @@ -"use client"; - -import { Loader2, Plus, Trash2 } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; -import { toast } from "sonner"; -import { Button } from "@/components/ui/button"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Textarea } from "@/components/ui/textarea"; -import { - MAX_SPEAKERS, - type PodcastSpec, - type PodcastStyle, - podcastStyle, - type SpeakerRole, - speakerRole, - type VoiceOption, -} from "@/contracts/types/podcast.types"; -import type { LivePodcast } from "@/hooks/use-podcast-live"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; -import { AppError } from "@/lib/error"; -import { VoicePreviewButton } from "./voice-preview-button"; - -// A "*" voice speaks whatever language the text is in (mirrors the backend -// catalog's ANY_LANGUAGE sentinel). -const ANY_LANGUAGE = "*"; - -function speaks(voice: VoiceOption, language: string): boolean { - if (voice.language === ANY_LANGUAGE) return true; - return primary(voice.language) === primary(language); -} - -function primary(language: string): string { - return language.split("-", 1)[0].trim().toLowerCase(); -} - -interface BriefReviewProps { - podcast: LivePodcast; - spec: PodcastSpec; -} - -/** - * The brief gate, rendered inline in the chat card: a pre-filled - * near-confirmation. One-click approve is the easy path; every field stays - * overridable and saves through the version-guarded PATCH so concurrent edits - * surface instead of clobbering. Approval needs no local follow-up — the - * pushed status flips the card to its drafting state. - */ -export function BriefReview({ podcast, spec }: BriefReviewProps) { - const [draft, setDraft] = useState<PodcastSpec>(spec); - const [voices, setVoices] = useState<VoiceOption[] | null>(null); - const [isSubmitting, setIsSubmitting] = useState(false); - - // A pushed spec change (saved edit or concurrent editor) resets the form to - // the authoritative version. - // biome-ignore lint/correctness/useExhaustiveDependencies: reset only when the server version moves - useEffect(() => { - setDraft(spec); - }, [podcast.specVersion]); - - useEffect(() => { - let cancelled = false; - podcastsApiService - .listVoices() - .then((catalog) => { - if (!cancelled) setVoices(catalog); - }) - .catch(() => { - if (!cancelled) setVoices([]); - }); - return () => { - cancelled = true; - }; - }, []); - - const languages = useMemo(() => { - const tags = new Set<string>(); - for (const voice of voices ?? []) { - if (voice.language !== ANY_LANGUAGE) tags.add(voice.language); - } - tags.add(draft.language); - return [...tags].sort(); - }, [voices, draft.language]); - - const voicesForLanguage = useMemo( - () => (voices ?? []).filter((voice) => speaks(voice, draft.language)), - [voices, draft.language] - ); - - const isDirty = useMemo(() => JSON.stringify(draft) !== JSON.stringify(spec), [draft, spec]); - - const setLanguage = (language: string) => { - setDraft((current) => { - const candidates = (voices ?? []).filter((voice) => speaks(voice, language)); - // Voices that can't render the new language are remapped so the saved - // spec never pairs a language with an incompatible voice. - const speakers = current.speakers.map((speaker, index) => { - const stillValid = candidates.some((voice) => voice.voice_id === speaker.voice_id); - const fallback = candidates[index % Math.max(candidates.length, 1)]; - return stillValid || !fallback ? speaker : { ...speaker, voice_id: fallback.voice_id }; - }); - return { ...current, language, speakers }; - }); - }; - - const setStyle = (style: PodcastStyle) => { - setDraft((current) => ({ - ...current, - style, - // A monologue has exactly one speaker, so extra speakers are dropped - // rather than letting approval fail validation. - speakers: style === "monologue" ? current.speakers.slice(0, 1) : current.speakers, - })); - }; - - const updateSpeaker = (slot: number, change: Partial<PodcastSpec["speakers"][number]>) => { - setDraft((current) => ({ - ...current, - speakers: current.speakers.map((speaker) => - speaker.slot === slot ? { ...speaker, ...change } : speaker - ), - })); - }; - - const addSpeaker = () => { - setDraft((current) => { - if (current.speakers.length >= MAX_SPEAKERS) return current; - const slot = Math.max(...current.speakers.map((s) => s.slot)) + 1; - const voice = - voicesForLanguage[current.speakers.length % Math.max(voicesForLanguage.length, 1)]; - return { - ...current, - speakers: [ - ...current.speakers, - { - slot, - name: `Speaker ${current.speakers.length + 1}`, - role: "guest" as SpeakerRole, - voice_id: voice?.voice_id ?? current.speakers[0].voice_id, - }, - ], - }; - }); - }; - - const removeSpeaker = (slot: number) => { - setDraft((current) => { - if (current.speakers.length <= 1) return current; - return { - ...current, - speakers: current.speakers.filter((speaker) => speaker.slot !== slot), - }; - }); - }; - - const saveIfDirty = async (): Promise<boolean> => { - if (!isDirty) return true; - try { - await podcastsApiService.updateSpec(podcast.id, draft, podcast.specVersion); - return true; - } catch (error) { - if (error instanceof AppError && error.status === 409) { - toast.warning("The brief changed elsewhere — reloaded the latest version."); - setDraft(spec); - } else { - toast.error(error instanceof Error ? error.message : "Failed to save the brief"); - } - return false; - } - }; - - const handleApprove = async () => { - setIsSubmitting(true); - try { - if (!(await saveIfDirty())) return; - await podcastsApiService.approveBrief(podcast.id); - } catch (error) { - toast.error(error instanceof Error ? error.message : "Failed to approve the brief"); - } finally { - setIsSubmitting(false); - } - }; - - return ( - <div className="flex flex-col gap-6"> - <div className="grid grid-cols-2 gap-4"> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-language">Language</Label> - <Select value={draft.language} onValueChange={setLanguage}> - <SelectTrigger id="podcast-language"> - <SelectValue placeholder="Language" /> - </SelectTrigger> - <SelectContent> - {languages.map((tag) => ( - <SelectItem key={tag} value={tag}> - {languageLabel(tag)} - </SelectItem> - ))} - </SelectContent> - </Select> - </div> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-style">Style</Label> - <Select value={draft.style} onValueChange={(value) => setStyle(value as PodcastStyle)}> - <SelectTrigger id="podcast-style"> - <SelectValue placeholder="Style" /> - </SelectTrigger> - <SelectContent> - {podcastStyle.options.map((style) => ( - <SelectItem key={style} value={style}> - {capitalize(style)} - </SelectItem> - ))} - </SelectContent> - </Select> - </div> - </div> - - <div className="flex flex-col gap-3"> - <div className="flex items-center justify-between"> - <Label>Speakers</Label> - <Button - type="button" - variant="ghost" - size="sm" - onClick={addSpeaker} - disabled={draft.style === "monologue" || draft.speakers.length >= MAX_SPEAKERS} - > - <Plus className="size-4" /> Add speaker - </Button> - </div> - {draft.speakers.map((speaker) => ( - <div key={speaker.slot} className="flex items-end gap-2 rounded-lg border p-3"> - <div className="flex flex-1 flex-col gap-1.5"> - <Label htmlFor={`speaker-name-${speaker.slot}`} className="text-xs"> - Name - </Label> - <Input - id={`speaker-name-${speaker.slot}`} - value={speaker.name} - maxLength={120} - onChange={(e) => updateSpeaker(speaker.slot, { name: e.target.value })} - /> - </div> - <div className="flex w-28 flex-col gap-1.5"> - <Label className="text-xs">Role</Label> - <Select - value={speaker.role} - onValueChange={(value) => - updateSpeaker(speaker.slot, { role: value as SpeakerRole }) - } - > - <SelectTrigger> - <SelectValue /> - </SelectTrigger> - <SelectContent> - {speakerRole.options.map((role) => ( - <SelectItem key={role} value={role}> - {capitalize(role)} - </SelectItem> - ))} - </SelectContent> - </Select> - </div> - <div className="flex w-52 flex-col gap-1.5"> - <Label className="text-xs">Voice</Label> - <div className="flex items-center gap-1"> - <Select - value={speaker.voice_id} - onValueChange={(value) => updateSpeaker(speaker.slot, { voice_id: value })} - > - <SelectTrigger> - <SelectValue placeholder={voices === null ? "Loading…" : "Voice"} /> - </SelectTrigger> - <SelectContent> - {voiceItems(voicesForLanguage, speaker.voice_id).map((voice) => ( - <SelectItem key={voice.voice_id} value={voice.voice_id}> - {voice.display_name} ({voice.gender}) - </SelectItem> - ))} - </SelectContent> - </Select> - <VoicePreviewButton voiceId={speaker.voice_id} /> - </div> - </div> - <Button - type="button" - variant="ghost" - size="icon" - aria-label={`Remove ${speaker.name}`} - onClick={() => removeSpeaker(speaker.slot)} - disabled={draft.speakers.length <= 1} - > - <Trash2 className="size-4" /> - </Button> - </div> - ))} - </div> - - <div className="grid grid-cols-2 gap-4"> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-min-minutes">Min length (minutes)</Label> - <Input - id="podcast-min-minutes" - type="number" - min={1} - value={draft.duration.min_minutes} - onChange={(e) => - setDraft((current) => ({ - ...current, - duration: { ...current.duration, min_minutes: Number(e.target.value) || 1 }, - })) - } - /> - </div> - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-max-minutes">Max length (minutes)</Label> - <Input - id="podcast-max-minutes" - type="number" - min={draft.duration.min_minutes} - value={draft.duration.max_minutes} - onChange={(e) => - setDraft((current) => ({ - ...current, - duration: { - ...current.duration, - max_minutes: Number(e.target.value) || current.duration.min_minutes, - }, - })) - } - /> - </div> - </div> - - <div className="flex flex-col gap-2"> - <Label htmlFor="podcast-focus">Focus (optional)</Label> - <Textarea - id="podcast-focus" - placeholder="What should the episode emphasise?" - maxLength={2000} - value={draft.focus ?? ""} - onChange={(e) => setDraft((current) => ({ ...current, focus: e.target.value || null }))} - /> - </div> - - <div className="flex justify-end gap-2"> - {isDirty ? ( - <Button - type="button" - variant="ghost" - onClick={() => setDraft(spec)} - disabled={isSubmitting} - > - Discard - </Button> - ) : null} - <Button - type="button" - onClick={handleApprove} - disabled={isSubmitting || draft.duration.max_minutes < draft.duration.min_minutes} - > - {isSubmitting ? <Loader2 className="size-4 animate-spin" /> : null} - {isDirty ? "Approve changes & draft transcript" : "Approve & draft transcript"} - </Button> - </div> - </div> - ); -} - -/** The current selection stays listed even when it no longer matches the - * language filter, so the Select never renders an orphaned value. */ -function voiceItems(candidates: VoiceOption[], selectedId: string): VoiceOption[] { - if (candidates.some((voice) => voice.voice_id === selectedId)) return candidates; - return [ - { voice_id: selectedId, display_name: selectedId, language: "", gender: "unknown" }, - ...candidates, - ]; -} - -function languageLabel(tag: string): string { - try { - const label = new Intl.DisplayNames(["en"], { type: "language" }).of(tag); - return label && label !== tag ? `${label} (${tag})` : tag; - } catch { - return tag; - } -} - -function capitalize(value: string): string { - return value.charAt(0).toUpperCase() + value.slice(1); -} diff --git a/surfsense_web/components/tool-ui/podcast/generate-podcast.tsx b/surfsense_web/components/tool-ui/podcast/generate-podcast.tsx deleted file mode 100644 index f881be9dd..000000000 --- a/surfsense_web/components/tool-ui/podcast/generate-podcast.tsx +++ /dev/null @@ -1,371 +0,0 @@ -"use client"; - -import type { ToolCallMessagePartProps } from "@assistant-ui/react"; -import { Loader2, RotateCcw, Undo2, X } from "lucide-react"; -import { usePathname } from "next/navigation"; -import { type ReactNode, useEffect, useState } from "react"; -import { toast } from "sonner"; -import { TextShimmerLoader } from "@/components/prompt-kit/loader"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, - AlertDialogTrigger, -} from "@/components/ui/alert-dialog"; -import { Button, buttonVariants } from "@/components/ui/button"; -import { type LivePodcast, usePodcastLive } from "@/hooks/use-podcast-live"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; -import { BriefReview } from "./brief-review"; -import { PodcastErrorState, PodcastPlayer } from "./player"; -import type { GeneratePodcastArgs, GeneratePodcastResult } from "./schema"; - -function WorkingState({ - title, - label, - action, -}: { - title: string; - label: string; - action?: ReactNode; -}) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="flex items-start justify-between gap-3 px-5 pt-5 pb-4"> - <div className="min-w-0"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <TextShimmerLoader text={label} size="sm" /> - </div> - {action} - </div> - </div> - ); -} - -function NoticeState({ title, message }: { title: string; message: string }) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-muted-foreground">{title}</p> - <p className="text-xs text-muted-foreground mt-0.5">{message}</p> - </div> - </div> - ); -} - -/** - * Regenerating reopens the brief and ultimately replaces the current audio, - * so a stray click is guarded by an inline confirm step. - */ -function RegenerateButton({ podcast }: { podcast: LivePodcast }) { - const [confirming, setConfirming] = useState(false); - const [isSubmitting, setIsSubmitting] = useState(false); - - const regenerate = async () => { - setIsSubmitting(true); - try { - await podcastsApiService.regenerate(podcast.id); - } catch (error) { - toast.error(error instanceof Error ? error.message : "Failed to regenerate the podcast"); - } finally { - setIsSubmitting(false); - setConfirming(false); - } - }; - - if (!confirming) { - return ( - <Button - type="button" - variant="ghost" - size="sm" - className="text-muted-foreground" - onClick={() => setConfirming(true)} - > - <RotateCcw className="size-3.5" /> Regenerate - </Button> - ); - } - - return ( - <div className="flex items-center gap-2"> - <span className="text-xs text-muted-foreground"> - Reopen the brief and replace this episode? - </span> - <Button - type="button" - variant="ghost" - size="sm" - onClick={() => setConfirming(false)} - disabled={isSubmitting} - > - Keep it - </Button> - <Button - type="button" - variant="destructive" - size="sm" - onClick={regenerate} - disabled={isSubmitting} - > - {isSubmitting ? <Loader2 className="size-3.5 animate-spin" /> : null} - Regenerate - </Button> - </div> - ); -} - -/** - * The way out of an in-flight generation depends on what already exists: - * a regeneration is reverted (the stored episode survives, so no confirm), - * while a first-time generation is cancelled (destructive, so confirmed via a - * dialog — the card header is too cramped to host a confirmation row). - */ -function BackOutButton({ podcastId, hasEpisode }: { podcastId: number; hasEpisode: boolean }) { - const [isSubmitting, setIsSubmitting] = useState(false); - - const run = async (call: (id: number) => Promise<unknown>, failure: string) => { - setIsSubmitting(true); - try { - await call(podcastId); - } catch (error) { - toast.error(error instanceof Error ? error.message : failure); - } finally { - setIsSubmitting(false); - } - }; - - if (hasEpisode) { - return ( - <Button - type="button" - variant="ghost" - size="sm" - className="shrink-0 text-muted-foreground" - disabled={isSubmitting} - onClick={() => - run(podcastsApiService.revertRegeneration, "Failed to restore the current episode") - } - > - {isSubmitting ? ( - <Loader2 className="size-3.5 animate-spin" /> - ) : ( - <Undo2 className="size-3.5" /> - )} - Keep current episode - </Button> - ); - } - - return ( - <AlertDialog> - <AlertDialogTrigger asChild> - <Button - type="button" - variant="ghost" - size="sm" - className="shrink-0 text-muted-foreground" - disabled={isSubmitting} - > - <X className="size-3.5" /> Cancel - </Button> - </AlertDialogTrigger> - <AlertDialogContent> - <AlertDialogHeader> - <AlertDialogTitle>Cancel this podcast?</AlertDialogTitle> - <AlertDialogDescription> - Generation stops and the podcast is discarded. This cannot be undone. - </AlertDialogDescription> - </AlertDialogHeader> - <AlertDialogFooter> - <AlertDialogCancel>Keep going</AlertDialogCancel> - <AlertDialogAction - className={buttonVariants({ variant: "destructive" })} - onClick={() => run(podcastsApiService.cancel, "Failed to cancel the podcast")} - > - Cancel podcast - </AlertDialogAction> - </AlertDialogFooter> - </AlertDialogContent> - </AlertDialog> - ); -} - -const BACK_OUT_STATUSES = new Set(["awaiting_brief", "drafting", "rendering"]); - -/** Status-driven card for an authenticated viewer, fed by Zero push. */ -function LivePodcastCard({ - podcastId, - fallbackTitle, -}: { - podcastId: number; - fallbackTitle: string; -}) { - const { podcast, isLoading } = usePodcastLive(podcastId); - - // Whether a finished episode exists decides revert-vs-cancel, and Zero - // doesn't publish audio fields — so the in-flight states check over REST, - // re-checking on each status change (a fresh podcast gains its episode, - // a regeneration starts with one). - const status = podcast?.status; - const [hasEpisode, setHasEpisode] = useState(false); - useEffect(() => { - if (!status || !BACK_OUT_STATUSES.has(status)) return; - let stale = false; - podcastsApiService - .getDetail(podcastId) - .then((detail) => { - if (!stale) setHasEpisode(detail.has_audio); - }) - .catch(() => {}); - return () => { - stale = true; - }; - }, [podcastId, status]); - - if (!podcast) { - if (isLoading) { - return <WorkingState title={fallbackTitle} label="Loading podcast" />; - } - return ( - <NoticeState - title="Podcast Unavailable" - message="This podcast no longer exists or you don't have access to it." - /> - ); - } - - const title = podcast.title || fallbackTitle; - - const backOut = <BackOutButton podcastId={podcast.id} hasEpisode={hasEpisode} />; - - switch (podcast.status) { - case "pending": - return <WorkingState title={title} label="Preparing brief" />; - case "drafting": - return <WorkingState title={title} label="Drafting transcript" action={backOut} />; - case "rendering": - return <WorkingState title={title} label="Rendering audio" action={backOut} />; - case "awaiting_brief": - // The gate lives right in the chat: the form is the card, so there - // is nothing to open and nothing to dismiss. - if (!podcast.spec) { - return <WorkingState title={title} label="Preparing brief" />; - } - return ( - <div className="my-4 max-w-xl overflow-hidden rounded-2xl border bg-muted/30"> - <div className="flex items-start justify-between gap-3 px-5 pt-5 pb-3 select-none"> - <div className="min-w-0"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <p className="text-xs text-muted-foreground mt-0.5"> - Confirm the language, voices, and length — the episode generates automatically after - you approve. - </p> - </div> - {backOut} - </div> - <div className="mx-5 h-px bg-border/50" /> - <div className="px-5 py-4"> - <BriefReview podcast={podcast} spec={podcast.spec} /> - </div> - </div> - ); - case "awaiting_review": - // Legacy rows parked at the removed transcript gate; the only way - // forward is regenerating through the brief gate. - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <p className="text-xs text-muted-foreground mt-0.5"> - This podcast was drafted before audio rendering became automatic. - </p> - </div> - <div className="mx-5 h-px bg-border/50" /> - <div className="flex justify-end px-5 py-3"> - <RegenerateButton podcast={podcast} /> - </div> - </div> - ); - case "ready": - return ( - <div> - <PodcastPlayer - podcastId={podcast.id} - title={title} - durationMs={podcast.durationSeconds ? podcast.durationSeconds * 1000 : undefined} - /> - <div className="-mt-2 mb-4 flex max-w-lg justify-end"> - <RegenerateButton podcast={podcast} /> - </div> - </div> - ); - case "failed": - return <PodcastErrorState title={title} error={podcast.error || "Generation failed"} />; - case "cancelled": - return <NoticeState title="Podcast Cancelled" message="This podcast was cancelled." />; - } -} - -/** - * Tool UI for `generate_podcast`. The tool only prepares the podcast (it - * returns with the brief awaiting review), so this card follows the lifecycle - * by Zero push, rendering the brief form inline at the gate. Public shared - * chats have no Zero session; their snapshots only ever contain finished - * episodes, so the player renders directly against the share-token endpoints. - */ -export const GeneratePodcastToolUI = ({ - args, - result, - status, -}: ToolCallMessagePartProps<GeneratePodcastArgs, GeneratePodcastResult>) => { - const pathname = usePathname(); - const isPublicRoute = !!pathname?.startsWith("/public/"); - const title = args.podcast_title || "SurfSense Podcast"; - - if (status.type === "running" || status.type === "requires-action") { - return <WorkingState title={title} label="Preparing podcast" />; - } - - if (status.type === "incomplete") { - if (status.reason === "cancelled") { - return <NoticeState title="Podcast Cancelled" message="Podcast preparation was cancelled." />; - } - if (status.reason === "error") { - return ( - <PodcastErrorState - title={title} - error={typeof status.error === "string" ? status.error : "An error occurred"} - /> - ); - } - } - - if (!result) { - return <WorkingState title={title} label="Preparing podcast" />; - } - - if (result.podcast_id) { - if (isPublicRoute) { - return <PodcastPlayer podcastId={result.podcast_id} title={result.title || title} />; - } - return <LivePodcastCard podcastId={result.podcast_id} fallbackTitle={result.title || title} />; - } - - if (result.status === "failed" || result.status === "error") { - return <PodcastErrorState title={title} error={result.error || "Generation failed"} />; - } - - // Legacy saved chats: results identified only by a Celery task id can't be - // recovered through the lifecycle API. - return ( - <NoticeState - title="Podcast Unavailable" - message="This podcast was generated with an older version. Please generate a new one." - /> - ); -}; diff --git a/surfsense_web/components/tool-ui/podcast/index.ts b/surfsense_web/components/tool-ui/podcast/index.ts deleted file mode 100644 index 1e5e5e06e..000000000 --- a/surfsense_web/components/tool-ui/podcast/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { GeneratePodcastToolUI } from "./generate-podcast"; diff --git a/surfsense_web/components/tool-ui/podcast/player.tsx b/surfsense_web/components/tool-ui/podcast/player.tsx deleted file mode 100644 index 2a3746844..000000000 --- a/surfsense_web/components/tool-ui/podcast/player.tsx +++ /dev/null @@ -1,209 +0,0 @@ -"use client"; - -import { useParams, usePathname } from "next/navigation"; -import { useCallback, useEffect, useRef, useState } from "react"; -import { z } from "zod"; -import { TextShimmerLoader } from "@/components/prompt-kit/loader"; -import { Audio } from "@/components/tool-ui/audio"; -import { - Accordion, - AccordionContent, - AccordionItem, - AccordionTrigger, -} from "@/components/ui/accordion"; -import { baseApiService } from "@/lib/apis/base-api.service"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; -import { authenticatedFetch } from "@/lib/auth-utils"; -import { BACKEND_URL } from "@/lib/env-config"; -import { speakerLabel } from "./schema"; - -// Public snapshots predate the transcript.turns shape and keep their own. -const publicPodcastDetailsSchema = z.object({ - podcast_transcript: z.array(z.object({ speaker_id: z.number(), dialog: z.string() })).nullish(), -}); - -interface TranscriptLine { - // Transcripts are immutable once fetched, so a turn's position identifies it. - key: string; - label: string; - text: string; -} - -export function PodcastErrorState({ title, error }: { title: string; error: string }) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-destructive">Podcast Generation Failed</p> - </div> - <div className="mx-5 h-px bg-border/50" /> - <div className="px-5 py-4"> - <p className="text-sm font-medium text-foreground line-clamp-2">{title}</p> - <p className="text-sm text-muted-foreground mt-1">{error}</p> - </div> - </div> - ); -} - -function AudioLoadingState({ title }: { title: string }) { - return ( - <div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none"> - <div className="px-5 pt-5 pb-4"> - <p className="text-sm font-semibold text-foreground line-clamp-2">{title}</p> - <TextShimmerLoader text="Loading audio" size="sm" /> - </div> - </div> - ); -} - -/** - * Streams the rendered episode and shows its transcript. Works in two modes: - * authenticated (lifecycle stream + detail endpoints) and public shared chat - * (share-token snapshot endpoints), detected from the route. - */ -export function PodcastPlayer({ - podcastId, - title, - durationMs, -}: { - podcastId: number; - title: string; - durationMs?: number; -}) { - const params = useParams(); - const pathname = usePathname(); - const isPublicRoute = pathname?.startsWith("/public/"); - const shareToken = isPublicRoute && typeof params?.token === "string" ? params.token : null; - - const [audioSrc, setAudioSrc] = useState<string | null>(null); - const [transcriptLines, setTranscriptLines] = useState<TranscriptLine[] | null>(null); - const [isLoading, setIsLoading] = useState(true); - const [error, setError] = useState<string | null>(null); - const objectUrlRef = useRef<string | null>(null); - - useEffect(() => { - return () => { - if (objectUrlRef.current) { - URL.revokeObjectURL(objectUrlRef.current); - } - }; - }, []); - - const loadPodcast = useCallback(async () => { - setIsLoading(true); - setError(null); - - try { - if (objectUrlRef.current) { - URL.revokeObjectURL(objectUrlRef.current); - objectUrlRef.current = null; - } - - const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 60000); - - try { - let audioBlob: Blob; - let lines: TranscriptLine[] = []; - - if (shareToken) { - const [blob, details] = await Promise.all([ - baseApiService.getBlob(`/api/v1/public/${shareToken}/podcasts/${podcastId}/stream`), - baseApiService.get(`/api/v1/public/${shareToken}/podcasts/${podcastId}`), - ]); - audioBlob = blob; - const parsed = publicPodcastDetailsSchema.safeParse(details); - lines = (parsed.success ? (parsed.data.podcast_transcript ?? []) : []).map( - (entry, turn) => ({ - key: `turn-${turn}`, - label: `Speaker ${entry.speaker_id + 1}`, - text: entry.dialog, - }) - ); - } else { - const [audioResponse, detail] = await Promise.all([ - authenticatedFetch(`${BACKEND_URL}/api/v1/podcasts/${podcastId}/stream`, { - method: "GET", - signal: controller.signal, - }), - podcastsApiService.getDetail(podcastId), - ]); - - if (!audioResponse.ok) { - throw new Error(`Failed to load audio: ${audioResponse.status}`); - } - - audioBlob = await audioResponse.blob(); - lines = (detail.transcript?.turns ?? []).map((entry, turn) => ({ - key: `turn-${turn}`, - label: speakerLabel(detail.spec, entry.speaker), - text: entry.text, - })); - } - - const objectUrl = URL.createObjectURL(audioBlob); - objectUrlRef.current = objectUrl; - setAudioSrc(objectUrl); - setTranscriptLines(lines); - } finally { - clearTimeout(timeoutId); - } - } catch (err) { - console.error("Error loading podcast:", err); - if (err instanceof DOMException && err.name === "AbortError") { - setError("Request timed out. Please try again."); - } else { - setError(err instanceof Error ? err.message : "Failed to load podcast"); - } - } finally { - setIsLoading(false); - } - }, [podcastId, shareToken]); - - useEffect(() => { - loadPodcast(); - }, [loadPodcast]); - - if (isLoading) { - return <AudioLoadingState title={title} />; - } - - if (error || !audioSrc) { - return <PodcastErrorState title={title} error={error || "Failed to load audio"} />; - } - - const hasTranscript = transcriptLines && transcriptLines.length > 0; - - return ( - <div className="my-4"> - <Audio - id={`podcast-${podcastId}`} - src={audioSrc} - title={title} - durationMs={durationMs} - className={hasTranscript ? "rounded-b-none border-b-0" : undefined} - /> - {hasTranscript ? ( - <div className="max-w-lg overflow-hidden rounded-b-2xl border border-t-0 bg-muted/30 select-none"> - <div className="mx-5 h-px bg-border/50" /> - <Accordion type="single" collapsible className="px-5"> - <AccordionItem value="transcript" className="border-b-0"> - <AccordionTrigger className="py-3 text-xs sm:text-sm font-medium text-muted-foreground hover:text-accent-foreground hover:no-underline"> - View transcript - </AccordionTrigger> - <AccordionContent className="pb-0"> - <div className="space-y-2 max-h-64 sm:max-h-96 overflow-y-auto select-text"> - {transcriptLines.map((line) => ( - <div key={line.key} className="text-xs sm:text-sm"> - <span className="font-medium text-primary">{line.label}:</span>{" "} - <span className="text-muted-foreground">{line.text}</span> - </div> - ))} - </div> - </AccordionContent> - </AccordionItem> - </Accordion> - </div> - ) : null} - </div> - ); -} diff --git a/surfsense_web/components/tool-ui/podcast/schema.ts b/surfsense_web/components/tool-ui/podcast/schema.ts deleted file mode 100644 index 91937eaad..000000000 --- a/surfsense_web/components/tool-ui/podcast/schema.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { z } from "zod"; -import type { PodcastSpec } from "@/contracts/types/podcast.types"; - -/** - * Tool-call contract for `generate_podcast`. - * - * The tool prepares a podcast and returns immediately with the row awaiting - * brief review; the card then follows the lifecycle by push. Legacy status - * values are accepted so old saved chats still render something sensible. - */ - -export const generatePodcastArgsSchema = z.object({ - source_content: z.string(), - podcast_title: z.string().nullish(), - user_prompt: z.string().nullish(), -}); -export type GeneratePodcastArgs = z.infer<typeof generatePodcastArgsSchema>; - -export const generatePodcastResultSchema = z.object({ - status: z.string(), - podcast_id: z.number().nullish(), - task_id: z.string().nullish(), // legacy Celery id from old saved chats - title: z.string().nullish(), - message: z.string().nullish(), - error: z.string().nullish(), -}); -export type GeneratePodcastResult = z.infer<typeof generatePodcastResultSchema>; - -/** Display name for the speaker bound to `slot`, falling back to a number. */ -export function speakerLabel(spec: PodcastSpec | null | undefined, slot: number): string { - const speaker = spec?.speakers.find((s) => s.slot === slot); - return speaker?.name ?? `Speaker ${slot + 1}`; -} diff --git a/surfsense_web/components/tool-ui/podcast/voice-preview-button.tsx b/surfsense_web/components/tool-ui/podcast/voice-preview-button.tsx deleted file mode 100644 index 989b15e0f..000000000 --- a/surfsense_web/components/tool-ui/podcast/voice-preview-button.tsx +++ /dev/null @@ -1,98 +0,0 @@ -"use client"; - -import { Loader2, Play, Square } from "lucide-react"; -import { useEffect, useRef, useState } from "react"; -import { toast } from "sonner"; -import { Button } from "@/components/ui/button"; -import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; - -// Comparing voices means replaying the same samples, so each voice is fetched -// at most once per page lifetime. -const sampleUrls = new Map<string, Promise<string>>(); - -// Overlapping samples are useless for comparison, so only one plays at a time. -let activeAudio: HTMLAudioElement | null = null; -let stopActive: (() => void) | null = null; - -function getSampleUrl(voiceId: string): Promise<string> { - let url = sampleUrls.get(voiceId); - if (!url) { - url = podcastsApiService.previewVoice(voiceId).then((blob) => URL.createObjectURL(blob)); - // A failed fetch must not poison the cache for retries. - url.catch(() => sampleUrls.delete(voiceId)); - sampleUrls.set(voiceId, url); - } - return url; -} - -/** Plays a short sample of `voiceId` so users pick voices by sound. */ -export function VoicePreviewButton({ voiceId }: { voiceId: string }) { - const [state, setState] = useState<"idle" | "loading" | "playing">("idle"); - const mountedRef = useRef(true); - - useEffect(() => { - mountedRef.current = true; - return () => { - mountedRef.current = false; - if (stopActive && activeAudio?.dataset.voiceId === voiceId) { - stopActive(); - } - }; - }, [voiceId]); - - const stop = () => { - if (stopActive) stopActive(); - }; - - const play = async () => { - stop(); - setState("loading"); - try { - const url = await getSampleUrl(voiceId); - if (!mountedRef.current) return; - - const audio = new Audio(url); - audio.dataset.voiceId = voiceId; - activeAudio = audio; - stopActive = () => { - audio.pause(); - activeAudio = null; - stopActive = null; - if (mountedRef.current) setState("idle"); - }; - audio.onended = () => { - if (activeAudio === audio) { - activeAudio = null; - stopActive = null; - } - if (mountedRef.current) setState("idle"); - }; - await audio.play(); - if (mountedRef.current) setState("playing"); - } catch (error) { - if (mountedRef.current) setState("idle"); - toast.error(error instanceof Error ? error.message : "Couldn't play the voice sample"); - } - }; - - const isPlaying = state === "playing"; - - return ( - <Button - type="button" - variant="ghost" - size="icon" - aria-label={isPlaying ? "Stop voice sample" : "Play voice sample"} - disabled={state === "loading"} - onClick={isPlaying ? stop : play} - > - {state === "loading" ? ( - <Loader2 className="size-4 animate-spin" /> - ) : isPlaying ? ( - <Square className="size-4" /> - ) : ( - <Play className="size-4" /> - )} - </Button> - ); -} diff --git a/surfsense_web/contracts/types/auth.types.ts b/surfsense_web/contracts/types/auth.types.ts index b630c461b..29a296c11 100644 --- a/surfsense_web/contracts/types/auth.types.ts +++ b/surfsense_web/contracts/types/auth.types.ts @@ -20,7 +20,8 @@ export const registerRequest = loginRequest.omit({ grant_type: true, username: t export const registerResponse = registerRequest.omit({ password: true }).extend({ id: z.string(), - credit_micros_balance: z.number(), + pages_limit: z.number(), + pages_used: z.number(), }); export type LoginRequest = z.infer<typeof loginRequest>; diff --git a/surfsense_web/contracts/types/inbox.types.ts b/surfsense_web/contracts/types/inbox.types.ts index 94e533809..b4cf01710 100644 --- a/surfsense_web/contracts/types/inbox.types.ts +++ b/surfsense_web/contracts/types/inbox.types.ts @@ -11,7 +11,7 @@ export const inboxItemTypeEnum = z.enum([ "document_processing", "new_mention", "comment_reply", - "insufficient_credits", + "page_limit_exceeded", ]); /** @@ -116,17 +116,15 @@ export const commentReplyMetadata = z.object({ }); /** - * Insufficient credits metadata schema. - * - * ``balance_micros`` / ``required_micros`` are integer micro-USD - * (1_000_000 == $1.00); the UI divides by 1M when displaying. + * Page limit exceeded metadata schema */ -export const insufficientCreditsMetadata = baseInboxItemMetadata.extend({ +export const pageLimitExceededMetadata = baseInboxItemMetadata.extend({ document_name: z.string(), document_type: z.string(), - balance_micros: z.number(), - required_micros: z.number(), - error_type: z.literal("insufficient_credits"), + pages_used: z.number(), + pages_limit: z.number(), + pages_to_add: z.number(), + error_type: z.literal("page_limit_exceeded"), // Navigation target for frontend action_url: z.string(), action_label: z.string(), @@ -142,7 +140,7 @@ export const inboxItemMetadata = z.union([ documentProcessingMetadata, newMentionMetadata, commentReplyMetadata, - insufficientCreditsMetadata, + pageLimitExceededMetadata, baseInboxItemMetadata, ]); @@ -190,9 +188,9 @@ export const commentReplyInboxItem = inboxItem.extend({ metadata: commentReplyMetadata, }); -export const insufficientCreditsInboxItem = inboxItem.extend({ - type: z.literal("insufficient_credits"), - metadata: insufficientCreditsMetadata, +export const pageLimitExceededInboxItem = inboxItem.extend({ + type: z.literal("page_limit_exceeded"), + metadata: pageLimitExceededMetadata, }); // ============================================================================= @@ -343,12 +341,12 @@ export function isCommentReplyMetadata(metadata: unknown): metadata is CommentRe } /** - * Type guard for InsufficientCreditsMetadata + * Type guard for PageLimitExceededMetadata */ -export function isInsufficientCreditsMetadata( +export function isPageLimitExceededMetadata( metadata: unknown -): metadata is InsufficientCreditsMetadata { - return insufficientCreditsMetadata.safeParse(metadata).success; +): metadata is PageLimitExceededMetadata { + return pageLimitExceededMetadata.safeParse(metadata).success; } /** @@ -363,7 +361,7 @@ export function parseInboxItemMetadata( | DocumentProcessingMetadata | NewMentionMetadata | CommentReplyMetadata - | InsufficientCreditsMetadata + | PageLimitExceededMetadata | null { switch (type) { case "connector_indexing": { @@ -386,8 +384,8 @@ export function parseInboxItemMetadata( const result = commentReplyMetadata.safeParse(metadata); return result.success ? result.data : null; } - case "insufficient_credits": { - const result = insufficientCreditsMetadata.safeParse(metadata); + case "page_limit_exceeded": { + const result = pageLimitExceededMetadata.safeParse(metadata); return result.success ? result.data : null; } default: @@ -408,7 +406,7 @@ export type ConnectorDeletionMetadata = z.infer<typeof connectorDeletionMetadata export type DocumentProcessingMetadata = z.infer<typeof documentProcessingMetadata>; export type NewMentionMetadata = z.infer<typeof newMentionMetadata>; export type CommentReplyMetadata = z.infer<typeof commentReplyMetadata>; -export type InsufficientCreditsMetadata = z.infer<typeof insufficientCreditsMetadata>; +export type PageLimitExceededMetadata = z.infer<typeof pageLimitExceededMetadata>; export type InboxItemMetadata = z.infer<typeof inboxItemMetadata>; export type InboxItem = z.infer<typeof inboxItem>; export type ConnectorIndexingInboxItem = z.infer<typeof connectorIndexingInboxItem>; @@ -416,7 +414,7 @@ export type ConnectorDeletionInboxItem = z.infer<typeof connectorDeletionInboxIt export type DocumentProcessingInboxItem = z.infer<typeof documentProcessingInboxItem>; export type NewMentionInboxItem = z.infer<typeof newMentionInboxItem>; export type CommentReplyInboxItem = z.infer<typeof commentReplyInboxItem>; -export type InsufficientCreditsInboxItem = z.infer<typeof insufficientCreditsInboxItem>; +export type PageLimitExceededInboxItem = z.infer<typeof pageLimitExceededInboxItem>; // API Request/Response types export type GetNotificationsRequest = z.infer<typeof getNotificationsRequest>; diff --git a/surfsense_web/contracts/types/incentive-tasks.types.ts b/surfsense_web/contracts/types/incentive-tasks.types.ts index abe91d905..c45121c29 100644 --- a/surfsense_web/contracts/types/incentive-tasks.types.ts +++ b/surfsense_web/contracts/types/incentive-tasks.types.ts @@ -12,8 +12,7 @@ export const incentiveTaskInfo = z.object({ task_type: incentiveTaskTypeEnum, title: z.string(), description: z.string(), - // Reward in micro-USD (1_000_000 == $1.00) credited to the wallet. - credit_micros_reward: z.number(), + pages_reward: z.number(), action_url: z.string(), completed: z.boolean(), completed_at: z.string().nullable(), @@ -24,7 +23,7 @@ export const incentiveTaskInfo = z.object({ */ export const getIncentiveTasksResponse = z.object({ tasks: z.array(incentiveTaskInfo), - total_credit_micros_earned: z.number(), + total_pages_earned: z.number(), }); /** @@ -33,8 +32,8 @@ export const getIncentiveTasksResponse = z.object({ export const completeTaskSuccessResponse = z.object({ success: z.literal(true), message: z.string(), - credit_micros_awarded: z.number(), - new_balance_micros: z.number(), + pages_awarded: z.number(), + new_pages_limit: z.number(), }); /** diff --git a/surfsense_web/contracts/types/podcast.types.ts b/surfsense_web/contracts/types/podcast.types.ts deleted file mode 100644 index e6332d5b2..000000000 --- a/surfsense_web/contracts/types/podcast.types.ts +++ /dev/null @@ -1,126 +0,0 @@ -import { z } from "zod"; - -// ============================================================================= -// Lifecycle — mirror app/podcasts/persistence/enums/podcast_status.py -// ============================================================================= - -export const podcastStatus = z.enum([ - "pending", - "awaiting_brief", - "drafting", - "awaiting_review", - "rendering", - "ready", - "failed", - "cancelled", -]); -export type PodcastStatus = z.infer<typeof podcastStatus>; - -/** - * States waiting on user input before the lifecycle can proceed. The brief is - * the only approval gate; `awaiting_review` survives in the enum for legacy - * rows but is never entered anymore. - */ -export const GATE_STATUSES: ReadonlySet<PodcastStatus> = new Set(["awaiting_brief"]); - -/** - * States from which no further transition is possible. A `ready` episode is - * not terminal: it can be sent back to drafting for regeneration. - */ -export const TERMINAL_STATUSES: ReadonlySet<PodcastStatus> = new Set(["failed", "cancelled"]); - -// ============================================================================= -// Brief (spec) — mirror app/podcasts/schemas/spec.py -// ============================================================================= - -export const speakerRole = z.enum(["host", "cohost", "guest", "expert", "narrator"]); -export type SpeakerRole = z.infer<typeof speakerRole>; - -export const podcastStyle = z.enum([ - "conversational", - "interview", - "debate", - "monologue", - "narrative", -]); -export type PodcastStyle = z.infer<typeof podcastStyle>; - -export const MAX_SPEAKERS = 6; - -export const speakerSpec = z.object({ - slot: z.number().int().min(0), - name: z.string().min(1).max(120), - role: speakerRole, - voice_id: z.string().min(1), -}); -export type SpeakerSpec = z.infer<typeof speakerSpec>; - -export const durationTarget = z.object({ - min_minutes: z.number().int().min(1), - max_minutes: z.number().int().min(1), -}); -export type DurationTarget = z.infer<typeof durationTarget>; - -export const podcastSpec = z - .object({ - language: z.string().min(2), - style: podcastStyle, - speakers: z.array(speakerSpec).min(1).max(MAX_SPEAKERS), - duration: durationTarget, - focus: z.string().max(2000).nullable().optional(), - }) - // Mirrors the backend invariant: one voice is what "monologue" means. - .refine((spec) => spec.style !== "monologue" || spec.speakers.length === 1, { - message: "A monologue has exactly one speaker", - path: ["speakers"], - }); -export type PodcastSpec = z.infer<typeof podcastSpec>; - -// ============================================================================= -// Transcript — mirror app/podcasts/schemas/transcript.py -// ============================================================================= - -export const transcriptTurn = z.object({ - speaker: z.number().int().min(0), - text: z.string().min(1), -}); -export type TranscriptTurn = z.infer<typeof transcriptTurn>; - -export const transcript = z.object({ - turns: z.array(transcriptTurn).min(1), -}); -export type Transcript = z.infer<typeof transcript>; - -// ============================================================================= -// API shapes — mirror app/podcasts/api/schemas.py -// ============================================================================= - -export const voiceOption = z.object({ - voice_id: z.string(), - display_name: z.string(), - language: z.string(), - gender: z.string(), -}); -export type VoiceOption = z.infer<typeof voiceOption>; - -export const updateSpecRequest = z.object({ - spec: podcastSpec, - expected_version: z.number().int().min(1), -}); -export type UpdateSpecRequest = z.infer<typeof updateSpecRequest>; - -export const podcastDetail = z.object({ - id: z.number(), - title: z.string(), - status: podcastStatus, - spec_version: z.number(), - spec: podcastSpec.nullable(), - transcript: transcript.nullable(), - has_audio: z.boolean(), - duration_seconds: z.number().nullable(), - error: z.string().nullable(), - created_at: z.string(), - search_space_id: z.number(), - thread_id: z.number().nullable(), -}); -export type PodcastDetail = z.infer<typeof podcastDetail>; diff --git a/surfsense_web/contracts/types/stripe.types.ts b/surfsense_web/contracts/types/stripe.types.ts index c548a3dd0..35ec0cb17 100644 --- a/surfsense_web/contracts/types/stripe.types.ts +++ b/surfsense_web/contracts/types/stripe.types.ts @@ -1,49 +1,20 @@ import { z } from "zod"; -export const purchaseStatusEnum = z.enum(["pending", "completed", "failed"]); +export const pagePurchaseStatusEnum = z.enum(["pending", "completed", "failed"]); -// --------------------------------------------------------------------------- -// Credit purchases ($1 packs that top up credit_micros_balance) -// --------------------------------------------------------------------------- - -export const createCreditCheckoutSessionRequest = z.object({ - quantity: z.number().int().min(1).max(10_000), +export const createCheckoutSessionRequest = z.object({ + quantity: z.number().int().min(1).max(100), search_space_id: z.number().int().min(1), }); -export const createCreditCheckoutSessionResponse = z.object({ +export const createCheckoutSessionResponse = z.object({ checkout_url: z.string(), }); -// Credit balance availability + records. Unit is integer micro-USD -// (1_000_000 == $1.00); the FE divides by 1M when displaying. -export const creditStripeStatusResponse = z.object({ - credit_buying_enabled: z.boolean(), - credit_micros_balance: z.number().default(0), +export const stripeStatusResponse = z.object({ + page_buying_enabled: z.boolean(), }); -export const creditPurchase = z.object({ - id: z.uuid(), - stripe_checkout_session_id: z.string(), - stripe_payment_intent_id: z.string().nullable(), - quantity: z.number(), - credit_micros_granted: z.number(), - amount_total: z.number().nullable(), - currency: z.string().nullable(), - source: z.string().default("checkout"), - status: purchaseStatusEnum, - completed_at: z.string().nullable(), - created_at: z.string(), -}); - -export const getCreditPurchasesResponse = z.object({ - purchases: z.array(creditPurchase), -}); - -// --------------------------------------------------------------------------- -// Legacy page purchases (read-only history; page buying is removed) -// --------------------------------------------------------------------------- - export const pagePurchase = z.object({ id: z.uuid(), stripe_checkout_session_id: z.string(), @@ -52,7 +23,7 @@ export const pagePurchase = z.object({ pages_granted: z.number(), amount_total: z.number().nullable(), currency: z.string().nullable(), - status: purchaseStatusEnum, + status: pagePurchaseStatusEnum, completed_at: z.string().nullable(), created_at: z.string(), }); @@ -61,59 +32,70 @@ export const getPagePurchasesResponse = z.object({ purchases: z.array(pagePurchase), }); -// Response from /stripe/finalize-checkout (credit purchases only). -export const finalizeCheckoutResponse = z.object({ - status: purchaseStatusEnum, - credit_micros_balance: z.number().default(0), - credit_micros_granted: z.number().nullable().optional(), -}); - -// --------------------------------------------------------------------------- -// Auto-reload (off-session top-up when the balance drops below a threshold) -// All *_micros fields are integer micro-USD (1_000_000 == $1.00). -// --------------------------------------------------------------------------- - -export const autoReloadSettingsResponse = z.object({ - feature_enabled: z.boolean(), - enabled: z.boolean().default(false), - threshold_micros: z.number().nullable(), - amount_micros: z.number().nullable(), - min_amount_micros: z.number(), - has_payment_method: z.boolean().default(false), - failed_at: z.string().nullable(), -}); - -export const updateAutoReloadSettingsRequest = z.object({ - enabled: z.boolean(), - threshold_micros: z.number().int().min(0).nullable().optional(), - amount_micros: z.number().int().min(0).nullable().optional(), -}); - -export const createAutoReloadSetupSessionRequest = z.object({ +// Premium credit purchases +export const createTokenCheckoutSessionRequest = z.object({ + quantity: z.number().int().min(1).max(100), search_space_id: z.number().int().min(1), }); -export const createAutoReloadSetupSessionResponse = z.object({ +export const createTokenCheckoutSessionResponse = z.object({ checkout_url: z.string(), }); -export type AutoReloadSettingsResponse = z.infer<typeof autoReloadSettingsResponse>; -export type UpdateAutoReloadSettingsRequest = z.infer<typeof updateAutoReloadSettingsRequest>; -export type CreateAutoReloadSetupSessionRequest = z.infer< - typeof createAutoReloadSetupSessionRequest ->; -export type CreateAutoReloadSetupSessionResponse = z.infer< - typeof createAutoReloadSetupSessionResponse ->; +// Premium credit balance + purchase records. +// +// The unit is integer micro-USD (1_000_000 == $1.00). The schema names +// kept the ``Token`` prefix for API back-compat with pinned clients; +// the field names below are authoritative. +export const tokenStripeStatusResponse = z.object({ + token_buying_enabled: z.boolean(), + premium_credit_micros_used: z.number().default(0), + premium_credit_micros_limit: z.number().default(0), + premium_credit_micros_remaining: z.number().default(0), +}); -export type PurchaseStatus = z.infer<typeof purchaseStatusEnum>; -export type CreateCreditCheckoutSessionRequest = z.infer<typeof createCreditCheckoutSessionRequest>; -export type CreateCreditCheckoutSessionResponse = z.infer< - typeof createCreditCheckoutSessionResponse ->; -export type CreditStripeStatusResponse = z.infer<typeof creditStripeStatusResponse>; -export type CreditPurchase = z.infer<typeof creditPurchase>; -export type GetCreditPurchasesResponse = z.infer<typeof getCreditPurchasesResponse>; +export const tokenPurchaseStatusEnum = pagePurchaseStatusEnum; + +export const tokenPurchase = z.object({ + id: z.uuid(), + stripe_checkout_session_id: z.string(), + stripe_payment_intent_id: z.string().nullable(), + quantity: z.number(), + credit_micros_granted: z.number(), + amount_total: z.number().nullable(), + currency: z.string().nullable(), + status: tokenPurchaseStatusEnum, + completed_at: z.string().nullable(), + created_at: z.string(), +}); + +export const getTokenPurchasesResponse = z.object({ + purchases: z.array(tokenPurchase), +}); + +// Response from /stripe/finalize-checkout. Either page or token fields +// are populated depending on purchase_type. +export const finalizeCheckoutResponse = z.object({ + purchase_type: z.enum(["page_packs", "premium_tokens"]), + status: pagePurchaseStatusEnum, + pages_limit: z.number().nullable().optional(), + pages_used: z.number().nullable().optional(), + pages_granted: z.number().nullable().optional(), + premium_credit_micros_limit: z.number().nullable().optional(), + premium_credit_micros_used: z.number().nullable().optional(), + premium_credit_micros_granted: z.number().nullable().optional(), +}); + +export type PagePurchaseStatus = z.infer<typeof pagePurchaseStatusEnum>; +export type CreateCheckoutSessionRequest = z.infer<typeof createCheckoutSessionRequest>; +export type CreateCheckoutSessionResponse = z.infer<typeof createCheckoutSessionResponse>; +export type StripeStatusResponse = z.infer<typeof stripeStatusResponse>; export type PagePurchase = z.infer<typeof pagePurchase>; export type GetPagePurchasesResponse = z.infer<typeof getPagePurchasesResponse>; +export type CreateTokenCheckoutSessionRequest = z.infer<typeof createTokenCheckoutSessionRequest>; +export type CreateTokenCheckoutSessionResponse = z.infer<typeof createTokenCheckoutSessionResponse>; +export type TokenStripeStatusResponse = z.infer<typeof tokenStripeStatusResponse>; +export type TokenPurchaseStatus = z.infer<typeof tokenPurchaseStatusEnum>; +export type TokenPurchase = z.infer<typeof tokenPurchase>; +export type GetTokenPurchasesResponse = z.infer<typeof getTokenPurchasesResponse>; export type FinalizeCheckoutResponse = z.infer<typeof finalizeCheckoutResponse>; diff --git a/surfsense_web/contracts/types/user.types.ts b/surfsense_web/contracts/types/user.types.ts index 706656064..85fee49a8 100644 --- a/surfsense_web/contracts/types/user.types.ts +++ b/surfsense_web/contracts/types/user.types.ts @@ -6,7 +6,8 @@ export const user = z.object({ is_active: z.boolean(), is_superuser: z.boolean(), is_verified: z.boolean(), - credit_micros_balance: z.number(), + pages_limit: z.number(), + pages_used: z.number(), display_name: z.string().nullish(), avatar_url: z.string().nullish(), }); diff --git a/surfsense_web/hooks/use-inbox.ts b/surfsense_web/hooks/use-inbox.ts index 860c0e01a..e1070219a 100644 --- a/surfsense_web/hooks/use-inbox.ts +++ b/surfsense_web/hooks/use-inbox.ts @@ -22,7 +22,7 @@ const CATEGORY_TYPES: Record<NotificationCategory, string[]> = { "connector_indexing", "connector_deletion", "document_processing", - "insufficient_credits", + "page_limit_exceeded", ], }; diff --git a/surfsense_web/hooks/use-podcast-live.ts b/surfsense_web/hooks/use-podcast-live.ts deleted file mode 100644 index e0a30e05b..000000000 --- a/surfsense_web/hooks/use-podcast-live.ts +++ /dev/null @@ -1,59 +0,0 @@ -"use client"; - -import { useQuery } from "@rocicorp/zero/react"; -import { useMemo } from "react"; -import { type PodcastSpec, type PodcastStatus, podcastSpec } from "@/contracts/types/podcast.types"; -import { queries } from "@/zero/queries"; - -/** - * Thin live row sourced from Zero's `podcasts` publication. Drives the - * lifecycle UI by push (no polling); heavy fields (transcript, audio) stay on - * REST and are fetched lazily when a gate or the player needs them. - */ -export interface LivePodcast { - id: number; - title: string; - status: PodcastStatus; - spec: PodcastSpec | null; - specVersion: number; - durationSeconds: number | null; - error: string | null; - searchSpaceId: number; - threadId: number | null; -} - -interface UsePodcastLiveResult { - podcast: LivePodcast | undefined; - isLoading: boolean; -} - -export function usePodcastLive(podcastId: number | undefined): UsePodcastLiveResult { - const [row, result] = useQuery(queries.podcasts.byId({ podcastId: podcastId ?? -1 })); - - const podcast = useMemo<LivePodcast | undefined>(() => { - if (!podcastId || !row) return undefined; - return { - id: row.id, - title: row.title, - status: row.status as PodcastStatus, - spec: parseSpec(row.spec), - specVersion: row.specVersion, - durationSeconds: row.durationSeconds ?? null, - error: row.error ?? null, - searchSpaceId: row.searchSpaceId, - threadId: row.threadId ?? null, - }; - }, [podcastId, row]); - - // Pre-hydration window: no row AND Zero hasn't confirmed completeness yet. - const isLoading = !!podcastId && !row && result.type !== "complete"; - - return { podcast, isLoading }; -} - -/** The JSONB column holds the snake_case spec; reject anything malformed. */ -function parseSpec(raw: unknown): PodcastSpec | null { - if (raw == null) return null; - const parsed = podcastSpec.safeParse(raw); - return parsed.success ? parsed.data : null; -} diff --git a/surfsense_web/lib/apis/podcasts-api.service.ts b/surfsense_web/lib/apis/podcasts-api.service.ts deleted file mode 100644 index bd7bb784e..000000000 --- a/surfsense_web/lib/apis/podcasts-api.service.ts +++ /dev/null @@ -1,69 +0,0 @@ -import { z } from "zod"; -import { - type PodcastSpec, - podcastDetail, - updateSpecRequest, - voiceOption, -} from "@/contracts/types/podcast.types"; -import { ValidationError } from "../error"; -import { baseApiService } from "./base-api.service"; - -const BASE = "/api/v1/podcasts"; - -const voiceOptionList = z.array(voiceOption); - -class PodcastsApiService { - // Full state including the deserialized brief and transcript; thin lifecycle - // fields (status, spec, spec_version) also arrive live via Zero. - getDetail = async (podcastId: number) => { - return baseApiService.get(`${BASE}/${podcastId}`, podcastDetail); - }; - - // Guarded by the version the caller last saw; the backend answers 409 when - // the brief changed underneath them. - updateSpec = async (podcastId: number, spec: PodcastSpec, expectedVersion: number) => { - const parsed = updateSpecRequest.safeParse({ spec, expected_version: expectedVersion }); - if (!parsed.success) { - throw new ValidationError( - `Invalid request: ${parsed.error.issues.map((i) => i.message).join(", ")}` - ); - } - return baseApiService.patch(`${BASE}/${podcastId}/spec`, podcastDetail, { - body: parsed.data, - }); - }; - - approveBrief = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/brief/approve`, podcastDetail); - }; - - // Reopens the brief gate; the transcript and audio are replaced once the - // user re-approves. - regenerate = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/transcript/regenerate`, podcastDetail); - }; - - // Backs out of a regeneration: the podcast returns to ready with its - // existing audio untouched. 409 when there is no episode to fall back to. - revertRegeneration = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/regenerate/revert`, podcastDetail); - }; - - // Only for podcasts that have produced nothing yet; once an episode - // exists the backend refuses (409) and revertRegeneration is the way back. - cancel = async (podcastId: number) => { - return baseApiService.post(`${BASE}/${podcastId}/cancel`, podcastDetail); - }; - - listVoices = async (language?: string) => { - const qs = language ? `?${new URLSearchParams({ language })}` : ""; - return baseApiService.get(`${BASE}/voices${qs}`, voiceOptionList); - }; - - // A short audio sample of a voice, cached server-side per voice. - previewVoice = async (voiceId: string) => { - return baseApiService.getBlob(`${BASE}/voices/${encodeURIComponent(voiceId)}/preview`); - }; -} - -export const podcastsApiService = new PodcastsApiService(); diff --git a/surfsense_web/lib/apis/stripe-api.service.ts b/surfsense_web/lib/apis/stripe-api.service.ts index b2f5698fb..f119fbf6a 100644 --- a/surfsense_web/lib/apis/stripe-api.service.ts +++ b/surfsense_web/lib/apis/stripe-api.service.ts @@ -1,50 +1,64 @@ import { - type AutoReloadSettingsResponse, - autoReloadSettingsResponse, - type CreateAutoReloadSetupSessionRequest, - type CreateAutoReloadSetupSessionResponse, - type CreateCreditCheckoutSessionRequest, - type CreateCreditCheckoutSessionResponse, - type CreditStripeStatusResponse, - createAutoReloadSetupSessionResponse, - createCreditCheckoutSessionResponse, - creditStripeStatusResponse, + type CreateCheckoutSessionRequest, + type CreateCheckoutSessionResponse, + type CreateTokenCheckoutSessionRequest, + type CreateTokenCheckoutSessionResponse, + createCheckoutSessionResponse, + createTokenCheckoutSessionResponse, type FinalizeCheckoutResponse, finalizeCheckoutResponse, - type GetCreditPurchasesResponse, type GetPagePurchasesResponse, - getCreditPurchasesResponse, + type GetTokenPurchasesResponse, getPagePurchasesResponse, - type UpdateAutoReloadSettingsRequest, + getTokenPurchasesResponse, + type StripeStatusResponse, + stripeStatusResponse, + type TokenStripeStatusResponse, + tokenStripeStatusResponse, } from "@/contracts/types/stripe.types"; import { baseApiService } from "./base-api.service"; class StripeApiService { - createCreditCheckoutSession = async ( - request: CreateCreditCheckoutSessionRequest - ): Promise<CreateCreditCheckoutSessionResponse> => { + createCheckoutSession = async ( + request: CreateCheckoutSessionRequest + ): Promise<CreateCheckoutSessionResponse> => { return baseApiService.post( - "/api/v1/stripe/create-credit-checkout-session", - createCreditCheckoutSessionResponse, + "/api/v1/stripe/create-checkout-session", + createCheckoutSessionResponse, + { + body: request, + } + ); + }; + + getPurchases = async (): Promise<GetPagePurchasesResponse> => { + return baseApiService.get("/api/v1/stripe/purchases", getPagePurchasesResponse); + }; + + getStatus = async (): Promise<StripeStatusResponse> => { + return baseApiService.get("/api/v1/stripe/status", stripeStatusResponse); + }; + + createTokenCheckoutSession = async ( + request: CreateTokenCheckoutSessionRequest + ): Promise<CreateTokenCheckoutSessionResponse> => { + return baseApiService.post( + "/api/v1/stripe/create-token-checkout-session", + createTokenCheckoutSessionResponse, { body: request } ); }; - getCreditStatus = async (): Promise<CreditStripeStatusResponse> => { - return baseApiService.get("/api/v1/stripe/credit-status", creditStripeStatusResponse); + getTokenStatus = async (): Promise<TokenStripeStatusResponse> => { + return baseApiService.get("/api/v1/stripe/token-status", tokenStripeStatusResponse); }; - getCreditPurchases = async (): Promise<GetCreditPurchasesResponse> => { - return baseApiService.get("/api/v1/stripe/credit-purchases", getCreditPurchasesResponse); - }; - - /** Legacy page-purchase history (read-only; page buying is removed). */ - getPagePurchases = async (): Promise<GetPagePurchasesResponse> => { - return baseApiService.get("/api/v1/stripe/purchases", getPagePurchasesResponse); + getTokenPurchases = async (): Promise<GetTokenPurchasesResponse> => { + return baseApiService.get("/api/v1/stripe/token-purchases", getTokenPurchasesResponse); }; /** - * Synchronously fulfil a credit checkout session from the success page. + * Synchronously fulfil a checkout session from the success page. * * Solves the race where the user lands on /purchase-success before * Stripe's checkout.session.completed webhook arrives. Idempotent — @@ -56,30 +70,6 @@ class StripeApiService { finalizeCheckoutResponse ); }; - - // --- Auto-reload -------------------------------------------------------- - - getAutoReloadSettings = async (): Promise<AutoReloadSettingsResponse> => { - return baseApiService.get("/api/v1/stripe/auto-reload", autoReloadSettingsResponse); - }; - - updateAutoReloadSettings = async ( - request: UpdateAutoReloadSettingsRequest - ): Promise<AutoReloadSettingsResponse> => { - return baseApiService.put("/api/v1/stripe/auto-reload", autoReloadSettingsResponse, { - body: request, - }); - }; - - createAutoReloadSetupSession = async ( - request: CreateAutoReloadSetupSessionRequest - ): Promise<CreateAutoReloadSetupSessionResponse> => { - return baseApiService.post( - "/api/v1/stripe/auto-reload/setup", - createAutoReloadSetupSessionResponse, - { body: request } - ); - }; } export const stripeApiService = new StripeApiService(); diff --git a/surfsense_web/lib/chat/podcast-state.ts b/surfsense_web/lib/chat/podcast-state.ts new file mode 100644 index 000000000..061a89b63 --- /dev/null +++ b/surfsense_web/lib/chat/podcast-state.ts @@ -0,0 +1,73 @@ +/** + * Module-level state for tracking active podcast generation. + * Used by the new-chat adapter to prevent duplicate podcast requests. + */ + +type PodcastStateListener = (isGenerating: boolean) => void; + +let _activePodcastTaskId: string | null = null; +const _listeners: Set<PodcastStateListener> = new Set(); + +/** + * Check if a podcast is currently being generated + */ +export function isPodcastGenerating(): boolean { + return _activePodcastTaskId !== null; +} + +/** + * Get the active podcast task ID + */ +export function getActivePodcastTaskId(): string | null { + return _activePodcastTaskId; +} + +/** + * Set the active podcast task ID (when podcast generation starts) + */ +export function setActivePodcastTaskId(taskId: string): void { + _activePodcastTaskId = taskId; + notifyListeners(); +} + +/** + * Clear the active podcast task ID (when podcast generation completes or errors) + */ +export function clearActivePodcastTaskId(): void { + _activePodcastTaskId = null; + notifyListeners(); +} + +/** + * Subscribe to podcast state changes + */ +export function subscribeToPodcastState(listener: PodcastStateListener): () => void { + _listeners.add(listener); + return () => { + _listeners.delete(listener); + }; +} + +function notifyListeners(): void { + const isGenerating = _activePodcastTaskId !== null; + for (const listener of _listeners) { + listener(isGenerating); + } +} + +/** + * Check if a message looks like a podcast request + */ +export function looksLikePodcastRequest(message: string): boolean { + const podcastPatterns = [ + /\bpodcast\b/i, + /\bcreate.*podcast\b/i, + /\bgenerate.*podcast\b/i, + /\bmake.*podcast\b/i, + /\bturn.*into.*podcast\b/i, + /\bpodcast.*about\b/i, + /\bgive.*podcast\b/i, + ]; + + return podcastPatterns.some((pattern) => pattern.test(message)); +} diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 1f8875a6d..4dc644d5e 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -569,10 +569,10 @@ export function trackIncentivePageViewed() { safeCapture("incentive_page_viewed"); } -export function trackIncentiveTaskCompleted(taskType: string, creditMicrosRewarded: number) { +export function trackIncentiveTaskCompleted(taskType: string, pagesRewarded: number) { safeCapture("incentive_task_completed", { task_type: taskType, - credit_micros_rewarded: creditMicrosRewarded, + pages_rewarded: pagesRewarded, }); } diff --git a/surfsense_web/package.json b/surfsense_web/package.json index ea8425be5..2e999b42c 100644 --- a/surfsense_web/package.json +++ b/surfsense_web/package.json @@ -1,6 +1,6 @@ { "name": "surfsense_web", - "version": "0.0.28", + "version": "0.0.27", "private": true, "packageManager": "pnpm@10.26.0", "description": "SurfSense Frontend", @@ -31,8 +31,8 @@ "dependencies": { "@ai-sdk/react": "^1.2.12", "@ariakit/react": "^0.4.21", - "@assistant-ui/react": "^0.14.14", - "@assistant-ui/react-markdown": "^0.14.1", + "@assistant-ui/react": "^0.12.19", + "@assistant-ui/react-markdown": "^0.12.6", "@babel/standalone": "^7.29.2", "@hookform/resolvers": "^5.2.2", "@marsidev/react-turnstile": "^1.5.0", diff --git a/surfsense_web/pnpm-lock.yaml b/surfsense_web/pnpm-lock.yaml index 4a5b0b5d0..652eff8f5 100644 --- a/surfsense_web/pnpm-lock.yaml +++ b/surfsense_web/pnpm-lock.yaml @@ -15,11 +15,11 @@ importers: specifier: ^0.4.21 version: 0.4.21(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@assistant-ui/react': - specifier: ^0.14.14 - version: 0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) + specifier: ^0.12.19 + version: 0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) '@assistant-ui/react-markdown': - specifier: ^0.14.1 - version: 0.14.1(@assistant-ui/react@0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + specifier: ^0.12.6 + version: 0.12.6(@assistant-ui/react@0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@babel/standalone': specifier: ^7.29.2 version: 7.29.2 @@ -498,13 +498,13 @@ packages: react: ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^17.0.0 || ^18.0.0 || ^19.0.0 - '@assistant-ui/core@0.2.10': - resolution: {integrity: sha512-0YyqlpZgg1Hoaq2X4jHAaMKXg+lGniLygNt1KrGFTPgbxeo8ZStRjWyyG2xIl+zlFKHiKCGHzflUHvlJi4IurA==} + '@assistant-ui/core@0.1.7': + resolution: {integrity: sha512-219T42ihVOicbJXZLWgD2CW5Bylg9Nk7geC331X4RfJxTDYlm2zIjViGlGaqfj6URXBp6kMulO2BTUrHGmAvdw==} peerDependencies: - '@assistant-ui/store': ^0.2.13 - '@assistant-ui/tap': ^0.5.14 + '@assistant-ui/store': ^0.2.3 + '@assistant-ui/tap': ^0.5.3 '@types/react': '*' - assistant-cloud: ^0.1.31 + assistant-cloud: ^0.1.22 react: ^18 || ^19 zustand: ^5.0.11 peerDependenciesMeta: @@ -517,18 +517,18 @@ packages: zustand: optional: true - '@assistant-ui/react-markdown@0.14.1': - resolution: {integrity: sha512-Q1S66rLS0J+b7jUjKrPGryLZsdg8v9NX/QdSTRmOCi5H6smWHfgMYvDypQ4BHn+4Tc+m3ggLKFPCgBV6t6iLhQ==} + '@assistant-ui/react-markdown@0.12.6': + resolution: {integrity: sha512-utJqsdDXB3UVZfOa3ErLpaTHraeXkDshR0D34shWdTHrmLyx9e/HypTu4+BgiSsxS+ME6t9WO9M3VeGDprfUcQ==} peerDependencies: - '@assistant-ui/react': ^0.14.8 + '@assistant-ui/react': ^0.12.19 '@types/react': '*' react: ^18 || ^19 peerDependenciesMeta: '@types/react': optional: true - '@assistant-ui/react@0.14.14': - resolution: {integrity: sha512-qS7YJewwFbmhs+yte56ZnO9jIOK+8hKo7mOK3cKDcCndn+jGSWTJmoNVIYQgMpB2JYIJ/SKZD+LeWSR6K3LL5g==} + '@assistant-ui/react@0.12.19': + resolution: {integrity: sha512-scAf0o8cwjuHT9Y44EFGXcE2y6BSmpeMvt0NxOn8+Y/HBlNttQMLNvrM0p2AjacXCUufagiafAnWybzBV3nKEQ==} peerDependencies: '@types/react': '*' '@types/react-dom': '*' @@ -540,18 +540,18 @@ packages: '@types/react-dom': optional: true - '@assistant-ui/store@0.2.13': - resolution: {integrity: sha512-7NL6HWMBxe1ndLWO4kHkjQ0Syyc0D/Aj+zxdpcy4yrplG71X04CzFimMBBSQAk+AnGBf+d96D7cuUZdjHkTavg==} + '@assistant-ui/store@0.2.3': + resolution: {integrity: sha512-daStbgSQiX7+csqK6Cvo7A8p8UZkTCSMxBHxbhJvwrlVbp7BRJWTxq3U3rpTkSGIar23SXIyVRRfXU8VW7pswA==} peerDependencies: - '@assistant-ui/tap': ^0.5.14 + '@assistant-ui/tap': ^0.5.3 '@types/react': '*' react: ^18 || ^19 peerDependenciesMeta: '@types/react': optional: true - '@assistant-ui/tap@0.5.14': - resolution: {integrity: sha512-SAy0ip8nKo72U8K9MuU7gYUR4tzoIi6k+HAQgev3zA/sWN7hr/QDDUTblrn5QB9Y/yycRiq8s98WD1vnDy8WMQ==} + '@assistant-ui/tap@0.5.3': + resolution: {integrity: sha512-wy06ksqF2LfFxe4JXy31Ns89N/be1Dy3c+mG363cFHFp3CbLkRu8CrCN2SQSgCkXt628E+D8QyzqdBcl9kD4NQ==} peerDependencies: '@types/react': '*' react: ^18 || ^19 @@ -5170,19 +5170,11 @@ packages: resolution: {integrity: sha512-BNoCY6SXXPQ7gF2opIP4GBE+Xw7U+pHMYKuzjgCN3GwiaIR09UUeKfheyIry77QtrCBlC0KK0q5/TER/tYh3PQ==} engines: {node: '>= 0.4'} - assistant-cloud@0.1.31: - resolution: {integrity: sha512-YBLc79w2EFD/6YjvcZrperpZ+B3TQ9LZ39AbjfcnbIJiSXYAs8cDH+mgy1GrfJBq47nhGaTVEf7ajv+hk084eA==} + assistant-cloud@0.1.22: + resolution: {integrity: sha512-AEE9shV+oFrGDv/MRTRERctNKpIYS0n34UpAQXXICiOkSWD6QZnS1ljLqruFko7fJoT5CIWq8dNeJWdzQLTBLg==} - assistant-stream@0.3.20: - resolution: {integrity: sha512-CniC84epmE9JrMSDzlZVWJ13O5rYbjoqEzh0jT+QfsrR07LBls42DMJ60XNxKXm8Hrn6MHSZcxqBUqwXRtoutA==} - peerDependencies: - ioredis: ^5.10.1 - redis: ^5.12.1 - peerDependenciesMeta: - ioredis: - optional: true - redis: - optional: true + assistant-stream@0.3.6: + resolution: {integrity: sha512-NdtSRrQfWCDA/aqQ1xhobf/xnhuMZkhFAw9xzAt5iAoL3ouxVXOowSRN87OL4MYBQEvqtcjw9/CE6YcsXoBtuw==} ast-types-flow@0.0.8: resolution: {integrity: sha512-OH/2E5Fg20h2aPrbe+QL8JZQFko0YZaF+j4mnQ7BGhfavO7OpSLa8a0y9sBwomHdSbkhTS8TQNayBfnW5DwbvQ==} @@ -8130,9 +8122,6 @@ packages: safe-buffer@5.2.1: resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} - safe-content-frame@0.0.20: - resolution: {integrity: sha512-saE3fBeGWOsi04PzTUaRi6RsBIjDYrZX4KzgIZUjbq3xQeOKYMcW1DeTb573Zyx1ggCDVJKoD/THchblISwjiQ==} - safe-push-apply@1.0.0: resolution: {integrity: sha512-iKE9w/Z7xCzUMIZqdBsp6pEQvwuEebH4vdpjcDWnyzaI6yl6O9FHvVpmGelvEHNsoY6wGblkxR6Zty/h00WiSA==} engines: {node: '>= 0.4'} @@ -8905,9 +8894,6 @@ packages: zod@4.3.6: resolution: {integrity: sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==} - zod@4.4.3: - resolution: {integrity: sha512-ytENFjIJFl2UwYglde2jchW2Hwm4GJFLDiSXWdTrJQBIN9Fcyp7n4DhxJEiWNAJMV1/BqWfW/kkg71UDcHJyTQ==} - zustand-x@6.2.1: resolution: {integrity: sha512-y3nQMQNx3BORY95vpuodJvh/8AqQu++S3q6mJYBSo1J0Q168Sy+FatqER658YESDqv2bwviXcIT3bgl/Ip6M5g==} peerDependencies: @@ -8931,24 +8917,6 @@ packages: use-sync-external-store: optional: true - zustand@5.0.14: - resolution: {integrity: sha512-/8tAspM5LMPr28b3fwLYrtdj77ECpfZviaP75CMTnwO8ISyaE4GDIG/9rDDYq/cH9D2Xw2A2RXglLInmVBQB/g==} - engines: {node: '>=12.20.0'} - peerDependencies: - '@types/react': '>=18.0.0' - immer: '>=9.0.6' - react: '>=18.0.0' - use-sync-external-store: '>=1.2.0' - peerDependenciesMeta: - '@types/react': - optional: true - immer: - optional: true - react: - optional: true - use-sync-external-store: - optional: true - zwitch@2.0.4: resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} @@ -9000,24 +8968,21 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@assistant-ui/core@0.2.10(@assistant-ui/store@0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.31)(react@19.2.4)(zustand@5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))': + '@assistant-ui/core@0.1.7(@assistant-ui/store@0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.22)(react@19.2.4)(zustand@5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))': dependencies: - '@assistant-ui/store': 0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) - '@assistant-ui/tap': 0.5.14(@types/react@19.2.14)(react@19.2.4) - assistant-stream: 0.3.20 - nanoid: 5.1.11 + '@assistant-ui/store': 0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/tap': 0.5.3(@types/react@19.2.14)(react@19.2.4) + assistant-stream: 0.3.6 + nanoid: 5.1.7 optionalDependencies: '@types/react': 19.2.14 - assistant-cloud: 0.1.31 + assistant-cloud: 0.1.22 react: 19.2.4 - zustand: 5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) - transitivePeerDependencies: - - ioredis - - redis + zustand: 5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) - '@assistant-ui/react-markdown@0.14.1(@assistant-ui/react@0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': + '@assistant-ui/react-markdown@0.12.6(@assistant-ui/react@0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)))(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': dependencies: - '@assistant-ui/react': 0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) + '@assistant-ui/react': 0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) '@radix-ui/react-primitive': 2.1.4(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@19.2.14)(react@19.2.4) classnames: 2.5.1 @@ -9030,45 +8995,42 @@ snapshots: - react-dom - supports-color - '@assistant-ui/react@0.14.14(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))': + '@assistant-ui/react@0.12.19(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(immer@10.2.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))': dependencies: - '@assistant-ui/core': 0.2.10(@assistant-ui/store@0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.31)(react@19.2.4)(zustand@5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))) - '@assistant-ui/store': 0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) - '@assistant-ui/tap': 0.5.14(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/core': 0.1.7(@assistant-ui/store@0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4))(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(assistant-cloud@0.1.22)(react@19.2.4)(zustand@5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))) + '@assistant-ui/store': 0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/tap': 0.5.3(@types/react@19.2.14)(react@19.2.4) '@radix-ui/primitive': 1.1.3 '@radix-ui/react-compose-refs': 1.1.2(@types/react@19.2.14)(react@19.2.4) '@radix-ui/react-context': 1.1.3(@types/react@19.2.14)(react@19.2.4) '@radix-ui/react-primitive': 2.1.4(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@19.2.14)(react@19.2.4) '@radix-ui/react-use-escape-keydown': 1.1.1(@types/react@19.2.14)(react@19.2.4) - assistant-cloud: 0.1.31 - assistant-stream: 0.3.20 - nanoid: 5.1.11 + assistant-cloud: 0.1.22 + assistant-stream: 0.3.6 + nanoid: 5.1.7 radix-ui: 1.4.3(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) react: 19.2.4 react-dom: 19.2.4(react@19.2.4) react-textarea-autosize: 8.5.9(@types/react@19.2.14)(react@19.2.4) - safe-content-frame: 0.0.20 - zod: 4.4.3 - zustand: 5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) + zod: 4.3.6 + zustand: 5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) optionalDependencies: '@types/react': 19.2.14 '@types/react-dom': 19.2.3(@types/react@19.2.14) transitivePeerDependencies: - immer - - ioredis - - redis - use-sync-external-store - '@assistant-ui/store@0.2.13(@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4)': + '@assistant-ui/store@0.2.3(@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4))(@types/react@19.2.14)(react@19.2.4)': dependencies: - '@assistant-ui/tap': 0.5.14(@types/react@19.2.14)(react@19.2.4) + '@assistant-ui/tap': 0.5.3(@types/react@19.2.14)(react@19.2.4) react: 19.2.4 use-effect-event: 2.0.3(react@19.2.4) optionalDependencies: '@types/react': 19.2.14 - '@assistant-ui/tap@0.5.14(@types/react@19.2.14)(react@19.2.4)': + '@assistant-ui/tap@0.5.3(@types/react@19.2.14)(react@19.2.4)': optionalDependencies: '@types/react': 19.2.14 react: 19.2.4 @@ -13829,17 +13791,14 @@ snapshots: get-intrinsic: 1.3.0 is-array-buffer: 3.0.5 - assistant-cloud@0.1.31: + assistant-cloud@0.1.22: dependencies: - assistant-stream: 0.3.20 - transitivePeerDependencies: - - ioredis - - redis + assistant-stream: 0.3.6 - assistant-stream@0.3.20: + assistant-stream@0.3.6: dependencies: '@standard-schema/spec': 1.1.0 - nanoid: 5.1.11 + nanoid: 5.1.7 secure-json-parse: 4.1.0 ast-types-flow@0.0.8: {} @@ -17494,8 +17453,6 @@ snapshots: safe-buffer@5.2.1: {} - safe-content-frame@0.0.20: {} - safe-push-apply@1.0.0: dependencies: es-errors: 1.3.0 @@ -18362,8 +18319,6 @@ snapshots: zod@4.3.6: {} - zod@4.4.3: {} - zustand-x@6.2.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(scheduler@0.27.0)(zustand@5.0.11(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4))): dependencies: immer: 10.2.0 @@ -18385,11 +18340,4 @@ snapshots: react: 19.2.4 use-sync-external-store: 1.6.0(react@19.2.4) - zustand@5.0.14(@types/react@19.2.14)(immer@10.2.0)(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)): - optionalDependencies: - '@types/react': 19.2.14 - immer: 10.2.0 - react: 19.2.4 - use-sync-external-store: 1.6.0(react@19.2.4) - zwitch@2.0.4: {} diff --git a/surfsense_web/zero/queries/index.ts b/surfsense_web/zero/queries/index.ts index 45df8fa98..fe711f5d3 100644 --- a/surfsense_web/zero/queries/index.ts +++ b/surfsense_web/zero/queries/index.ts @@ -4,7 +4,6 @@ import { chatSessionQueries, commentQueries, messageQueries } from "./chat"; import { connectorQueries, documentQueries } from "./documents"; import { folderQueries } from "./folders"; import { notificationQueries } from "./inbox"; -import { podcastQueries } from "./podcasts"; import { userQueries } from "./user"; export const queries = defineQueries({ @@ -17,5 +16,4 @@ export const queries = defineQueries({ chatSession: chatSessionQueries, user: userQueries, automationRuns: automationRunQueries, - podcasts: podcastQueries, }); diff --git a/surfsense_web/zero/queries/podcasts.ts b/surfsense_web/zero/queries/podcasts.ts deleted file mode 100644 index 5298534dd..000000000 --- a/surfsense_web/zero/queries/podcasts.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { defineQuery } from "@rocicorp/zero"; -import { z } from "zod"; -import { zql } from "../schema/index"; - -export const podcastQueries = { - bySpace: defineQuery(z.object({ searchSpaceId: z.number() }), ({ args: { searchSpaceId } }) => - zql.podcasts.where("searchSpaceId", searchSpaceId).orderBy("createdAt", "desc") - ), - byId: defineQuery(z.object({ podcastId: z.number() }), ({ args: { podcastId } }) => - zql.podcasts.where("id", podcastId).one() - ), -}; diff --git a/surfsense_web/zero/schema/index.ts b/surfsense_web/zero/schema/index.ts index d1187ddab..d6731e371 100644 --- a/surfsense_web/zero/schema/index.ts +++ b/surfsense_web/zero/schema/index.ts @@ -4,7 +4,6 @@ import { chatCommentTable, chatSessionStateTable, newChatMessageTable } from "./ import { documentTable, searchSourceConnectorTable } from "./documents"; import { folderTable } from "./folders"; import { notificationTable } from "./inbox"; -import { podcastTable } from "./podcasts"; import { userTable } from "./user"; const chatCommentRelationships = relationships(chatCommentTable, ({ one }) => ({ @@ -39,7 +38,6 @@ export const schema = createSchema({ chatSessionStateTable, userTable, automationRunTable, - podcastTable, ], relationships: [chatCommentRelationships, newChatMessageRelationships], }); diff --git a/surfsense_web/zero/schema/podcasts.ts b/surfsense_web/zero/schema/podcasts.ts deleted file mode 100644 index d473d776f..000000000 --- a/surfsense_web/zero/schema/podcasts.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { json, number, string, table } from "@rocicorp/zero"; - -// Mirrors PODCAST_COLS in the backend zero_publication. status drives the -// lifecycle UI by push; spec is the reviewable brief. The bulky source_content -// and transcript are intentionally not published and are fetched over REST. -export const podcastTable = table("podcasts") - .columns({ - id: number(), - title: string(), - status: string(), - spec: json().optional(), - specVersion: number().from("spec_version"), - durationSeconds: number().optional().from("duration_seconds"), - error: string().optional(), - searchSpaceId: number().from("search_space_id"), - threadId: number().optional().from("thread_id"), - createdAt: number().from("created_at"), - }) - .primaryKey("id"); diff --git a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts index 3b6c3ec92..f483fa9b4 100644 --- a/surfsense_web/zero/schema/user.ts +++ b/surfsense_web/zero/schema/user.ts @@ -3,16 +3,18 @@ import { number, string, table } from "@rocicorp/zero"; /** * Live-meter slice of the ``user`` table replicated through Zero. * - * ``creditMicrosBalance`` is stored as integer micro-USD (1_000_000 == $1.00); - * UI consumers divide by 1M when displaying and clamp at $0.00 (the balance can - * dip slightly negative when actual cost exceeds the pre-charge estimate). - * Sensitive fields (email, hashed_password, oauth, etc.) are intentionally - * omitted via the Postgres column-list publication so they never enter WAL - * replication. + * ``premiumCreditMicrosLimit`` / ``premiumCreditMicrosUsed`` are stored + * as integer micro-USD (1_000_000 == $1.00). UI consumers divide by 1M + * when displaying. Sensitive fields (email, hashed_password, oauth, etc.) + * are intentionally omitted via the Postgres column-list publication so + * they never enter WAL replication. */ export const userTable = table("user") .columns({ id: string(), - creditMicrosBalance: number().from("credit_micros_balance"), + pagesLimit: number().from("pages_limit"), + pagesUsed: number().from("pages_used"), + premiumCreditMicrosLimit: number().from("premium_credit_micros_limit"), + premiumCreditMicrosUsed: number().from("premium_credit_micros_used"), }) .primaryKey("id");