diff --git a/surfsense_backend/tests/unit/observability/test_helpers.py b/surfsense_backend/tests/unit/observability/test_helpers.py index ae60c1939..61cf3d591 100644 --- a/surfsense_backend/tests/unit/observability/test_helpers.py +++ b/surfsense_backend/tests/unit/observability/test_helpers.py @@ -31,7 +31,8 @@ def _disable_otel(monkeypatch: pytest.MonkeyPatch): ("process_file_upload_with_document", "process"), ("process_circleback_meeting", "process"), ("generate_video_presentation", "generate"), - ("generate_content_podcast", "generate"), + ("podcast.draft_transcript", "podcast.draft"), + ("podcast.render_audio", "podcast.render"), ("cleanup_stale_indexing_notifications", "cleanup"), ("reconcile_pending_stripe_page_purchases", "reconcile"), ("reconcile_pending_stripe_token_purchases", "reconcile"), diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py index a5bb3f58a..2342dd8da 100644 --- a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py +++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py @@ -239,17 +239,18 @@ def test_video_presentation_task_uses_runner_helper() -> None: ) -def test_podcast_task_uses_runner_helper() -> None: - """Symmetric assertion for the podcast task — same root cause, same +def test_podcast_tasks_use_runner_helper() -> None: + """Symmetric assertion for the podcast tasks — same root cause, same fix, same regression risk. """ import inspect - from app.tasks.celery_tasks import podcast_tasks + from app.podcasts.tasks import draft, render - src = inspect.getsource(podcast_tasks) - assert "run_async_celery_task" in src - assert "asyncio.new_event_loop" not in src + for module in (draft, render): + src = inspect.getsource(module) + assert "run_async_celery_task" in src + assert "asyncio.new_event_loop" not in src def test_runner_runs_shutdown_asyncgens_before_close() -> None: diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py deleted file mode 100644 index 699297df1..000000000 --- a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py +++ /dev/null @@ -1,388 +0,0 @@ -"""Unit tests for podcast Celery task billing integration. - -Validates ``_generate_content_podcast`` correctly wraps -``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the -search-space owner's billing decision, and degrades cleanly when the -resolver fails or premium credit is exhausted. - -Coverage: - -* Happy-path free config: resolver → ``billable_call`` enters with - ``usage_type='podcast_generation'`` and the configured reserve override, - graph runs, podcast row flips to ``READY``. -* Happy-path premium config: same wiring with ``billing_tier='premium'``. -* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` → - graph is *not* invoked, podcast row flips to ``FAILED``, return dict - carries ``reason='premium_quota_exhausted'``. -* Resolver failure: ``ValueError`` from the resolver → podcast row flips - to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``. -""" - -from __future__ import annotations - -import contextlib -from types import SimpleNamespace -from typing import Any -from uuid import uuid4 - -import pytest - -pytestmark = pytest.mark.unit - - -# --------------------------------------------------------------------------- -# Fakes -# --------------------------------------------------------------------------- - - -class _FakeExecResult: - def __init__(self, obj): - self._obj = obj - - def scalars(self): - return self - - def first(self): - return self._obj - - def filter(self, *_args, **_kwargs): - return self - - -class _FakeSession: - def __init__(self, podcast): - self._podcast = podcast - self.commit_count = 0 - - async def execute(self, _stmt): - return _FakeExecResult(self._podcast) - - async def commit(self): - self.commit_count += 1 - - async def __aenter__(self): - return self - - async def __aexit__(self, *args): - return None - - -class _FakeSessionMaker: - def __init__(self, session: _FakeSession): - self._session = session - - def __call__(self): - return self._session - - -def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace: - """Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily - inside helpers keeps this fixture cheap.""" - return SimpleNamespace( - id=podcast_id, - title="Test Podcast", - thread_id=thread_id, - status=None, - podcast_transcript=None, - file_location=None, - ) - - -@contextlib.asynccontextmanager -async def _ok_billable_call(**kwargs): - """Stand-in for ``billable_call`` that records its kwargs and yields a - no-op accumulator-shaped object.""" - _CALL_LOG.append(kwargs) - yield SimpleNamespace() - - -_CALL_LOG: list[dict[str, Any]] = [] - - -@contextlib.asynccontextmanager -async def _denying_billable_call(**kwargs): - from app.services.billable_calls import QuotaInsufficientError - - _CALL_LOG.append(kwargs) - raise QuotaInsufficientError( - usage_type=kwargs.get("usage_type", "?"), - used_micros=5_000_000, - limit_micros=5_000_000, - remaining_micros=0, - ) - yield SimpleNamespace() # pragma: no cover — for grammar only - - -@contextlib.asynccontextmanager -async def _settlement_failing_billable_call(**kwargs): - from app.services.billable_calls import BillingSettlementError - - _CALL_LOG.append(kwargs) - yield SimpleNamespace() - raise BillingSettlementError( - usage_type=kwargs.get("usage_type", "?"), - user_id=kwargs["user_id"], - cause=RuntimeError("finalize failed"), - ) - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -@pytest.fixture(autouse=True) -def _reset_call_log(): - _CALL_LOG.clear() - yield - _CALL_LOG.clear() - - -@pytest.mark.asyncio -async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch): - """Happy path: free billing tier still wraps the graph call so the - audit row is recorded. Verifies kwargs threading.""" - from app.config import config as app_config - from app.db import PodcastStatus - from app.tasks.celery_tasks import podcast_tasks - - podcast = _make_podcast(podcast_id=7, thread_id=99) - session = _FakeSession(podcast) - monkeypatch.setattr( - podcast_tasks, - "get_celery_session_maker", - lambda: _FakeSessionMaker(session), - ) - - user_id = uuid4() - - async def _fake_resolver(sess, search_space_id, *, thread_id=None): - assert search_space_id == 555 - assert thread_id == 99 - return user_id, "free", "openrouter/some-free-model" - - monkeypatch.setattr( - podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver - ) - monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) - - async def _fake_graph_invoke(state, config): - return { - "podcast_transcript": [ - SimpleNamespace(speaker_id=0, dialog="Hi"), - SimpleNamespace(speaker_id=1, dialog="Hello"), - ], - "final_podcast_file_path": "/tmp/podcast.wav", - } - - monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) - - result = await podcast_tasks._generate_content_podcast( - podcast_id=7, - source_content="hello world", - search_space_id=555, - user_prompt="make it short", - ) - - assert result["status"] == "ready" - assert result["podcast_id"] == 7 - assert podcast.status == PodcastStatus.READY - assert podcast.file_location == "/tmp/podcast.wav" - - assert len(_CALL_LOG) == 1 - call = _CALL_LOG[0] - assert call["user_id"] == user_id - assert call["search_space_id"] == 555 - assert call["billing_tier"] == "free" - assert call["base_model"] == "openrouter/some-free-model" - assert call["usage_type"] == "podcast_generation" - assert ( - call["quota_reserve_micros_override"] - == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS - ) - # Background artifact audit rows intentionally omit the TokenUsage.thread_id - # FK to avoid coupling Celery audit commits to an active chat transaction. - assert "thread_id" not in call - assert call["call_details"] == { - "podcast_id": 7, - "title": "Test Podcast", - "thread_id": 99, - } - assert callable(call["billable_session_factory"]) - - -@pytest.mark.asyncio -async def test_billable_call_invoked_with_premium_tier(monkeypatch): - """Premium resolution flows through to ``billable_call`` so the - reserve/finalize path triggers.""" - from app.tasks.celery_tasks import podcast_tasks - - podcast = _make_podcast() - session = _FakeSession(podcast) - monkeypatch.setattr( - podcast_tasks, - "get_celery_session_maker", - lambda: _FakeSessionMaker(session), - ) - - user_id = uuid4() - - async def _fake_resolver(sess, search_space_id, *, thread_id=None): - return user_id, "premium", "gpt-5.4" - - monkeypatch.setattr( - podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver - ) - monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call) - - async def _fake_graph_invoke(state, config): - return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} - - monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) - - await podcast_tasks._generate_content_podcast( - podcast_id=7, - source_content="hi", - search_space_id=555, - user_prompt=None, - ) - - assert _CALL_LOG[0]["billing_tier"] == "premium" - assert _CALL_LOG[0]["base_model"] == "gpt-5.4" - - -@pytest.mark.asyncio -async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch): - """When ``billable_call`` denies the reservation, the graph never - runs and the podcast row flips to FAILED with the documented reason - code.""" - from app.db import PodcastStatus - from app.tasks.celery_tasks import podcast_tasks - - podcast = _make_podcast(podcast_id=8) - session = _FakeSession(podcast) - monkeypatch.setattr( - podcast_tasks, - "get_celery_session_maker", - lambda: _FakeSessionMaker(session), - ) - - async def _fake_resolver(sess, search_space_id, *, thread_id=None): - return uuid4(), "premium", "gpt-5.4" - - monkeypatch.setattr( - podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver - ) - monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call) - - graph_invoked = [] - - async def _fake_graph_invoke(state, config): - graph_invoked.append(True) - return {} - - monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) - - result = await podcast_tasks._generate_content_podcast( - podcast_id=8, - source_content="hi", - search_space_id=555, - user_prompt=None, - ) - - assert result == { - "status": "failed", - "podcast_id": 8, - "reason": "premium_quota_exhausted", - } - assert podcast.status == PodcastStatus.FAILED - assert graph_invoked == [] # Graph never ran on denied reservation. - - -@pytest.mark.asyncio -async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch): - from app.db import PodcastStatus - from app.tasks.celery_tasks import podcast_tasks - - podcast = _make_podcast(podcast_id=10) - session = _FakeSession(podcast) - monkeypatch.setattr( - podcast_tasks, - "get_celery_session_maker", - lambda: _FakeSessionMaker(session), - ) - - async def _fake_resolver(sess, search_space_id, *, thread_id=None): - return uuid4(), "premium", "gpt-5.4" - - monkeypatch.setattr( - podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver - ) - monkeypatch.setattr( - podcast_tasks, "billable_call", _settlement_failing_billable_call - ) - - async def _fake_graph_invoke(state, config): - return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} - - monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) - - result = await podcast_tasks._generate_content_podcast( - podcast_id=10, - source_content="hi", - search_space_id=555, - user_prompt=None, - ) - - assert result == { - "status": "failed", - "podcast_id": 10, - "reason": "billing_settlement_failed", - } - assert podcast.status == PodcastStatus.FAILED - - -@pytest.mark.asyncio -async def test_resolver_failure_marks_podcast_failed(monkeypatch): - """If the resolver raises (e.g. search-space deleted), the task fails - cleanly without invoking the graph.""" - from app.db import PodcastStatus - from app.tasks.celery_tasks import podcast_tasks - - podcast = _make_podcast(podcast_id=9) - session = _FakeSession(podcast) - monkeypatch.setattr( - podcast_tasks, - "get_celery_session_maker", - lambda: _FakeSessionMaker(session), - ) - - async def _failing_resolver(sess, search_space_id, *, thread_id=None): - raise ValueError("Search space 555 not found") - - monkeypatch.setattr( - podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver - ) - - graph_invoked = [] - - async def _fake_graph_invoke(state, config): - graph_invoked.append(True) - return {} - - monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) - - result = await podcast_tasks._generate_content_podcast( - podcast_id=9, - source_content="hi", - search_space_id=555, - user_prompt=None, - ) - - assert result == { - "status": "failed", - "podcast_id": 9, - "reason": "billing_resolution_failed", - } - assert podcast.status == PodcastStatus.FAILED - assert graph_invoked == []