mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
test(podcasts): retarget celery and observability tests to new tasks
This commit is contained in:
parent
97ab7a88fd
commit
8f38737ad9
3 changed files with 9 additions and 395 deletions
|
|
@ -31,7 +31,8 @@ def _disable_otel(monkeypatch: pytest.MonkeyPatch):
|
|||
("process_file_upload_with_document", "process"),
|
||||
("process_circleback_meeting", "process"),
|
||||
("generate_video_presentation", "generate"),
|
||||
("generate_content_podcast", "generate"),
|
||||
("podcast.draft_transcript", "podcast.draft"),
|
||||
("podcast.render_audio", "podcast.render"),
|
||||
("cleanup_stale_indexing_notifications", "cleanup"),
|
||||
("reconcile_pending_stripe_page_purchases", "reconcile"),
|
||||
("reconcile_pending_stripe_token_purchases", "reconcile"),
|
||||
|
|
|
|||
|
|
@ -239,17 +239,18 @@ def test_video_presentation_task_uses_runner_helper() -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_podcast_task_uses_runner_helper() -> None:
|
||||
"""Symmetric assertion for the podcast task — same root cause, same
|
||||
def test_podcast_tasks_use_runner_helper() -> None:
|
||||
"""Symmetric assertion for the podcast tasks — same root cause, same
|
||||
fix, same regression risk.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
from app.podcasts.tasks import draft, render
|
||||
|
||||
src = inspect.getsource(podcast_tasks)
|
||||
assert "run_async_celery_task" in src
|
||||
assert "asyncio.new_event_loop" not in src
|
||||
for module in (draft, render):
|
||||
src = inspect.getsource(module)
|
||||
assert "run_async_celery_task" in src
|
||||
assert "asyncio.new_event_loop" not in src
|
||||
|
||||
|
||||
def test_runner_runs_shutdown_asyncgens_before_close() -> None:
|
||||
|
|
|
|||
|
|
@ -1,388 +0,0 @@
|
|||
"""Unit tests for podcast Celery task billing integration.
|
||||
|
||||
Validates ``_generate_content_podcast`` correctly wraps
|
||||
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
|
||||
search-space owner's billing decision, and degrades cleanly when the
|
||||
resolver fails or premium credit is exhausted.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Happy-path free config: resolver → ``billable_call`` enters with
|
||||
``usage_type='podcast_generation'`` and the configured reserve override,
|
||||
graph runs, podcast row flips to ``READY``.
|
||||
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
|
||||
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
|
||||
graph is *not* invoked, podcast row flips to ``FAILED``, return dict
|
||||
carries ``reason='premium_quota_exhausted'``.
|
||||
* Resolver failure: ``ValueError`` from the resolver → podcast row flips
|
||||
to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
def filter(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, podcast):
|
||||
self._podcast = podcast
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self._podcast)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeSessionMaker:
|
||||
def __init__(self, session: _FakeSession):
|
||||
self._session = session
|
||||
|
||||
def __call__(self):
|
||||
return self._session
|
||||
|
||||
|
||||
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
|
||||
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
|
||||
inside helpers keeps this fixture cheap."""
|
||||
return SimpleNamespace(
|
||||
id=podcast_id,
|
||||
title="Test Podcast",
|
||||
thread_id=thread_id,
|
||||
status=None,
|
||||
podcast_transcript=None,
|
||||
file_location=None,
|
||||
)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _ok_billable_call(**kwargs):
|
||||
"""Stand-in for ``billable_call`` that records its kwargs and yields a
|
||||
no-op accumulator-shaped object."""
|
||||
_CALL_LOG.append(kwargs)
|
||||
yield SimpleNamespace()
|
||||
|
||||
|
||||
_CALL_LOG: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _denying_billable_call(**kwargs):
|
||||
from app.services.billable_calls import QuotaInsufficientError
|
||||
|
||||
_CALL_LOG.append(kwargs)
|
||||
raise QuotaInsufficientError(
|
||||
usage_type=kwargs.get("usage_type", "?"),
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield SimpleNamespace() # pragma: no cover — for grammar only
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _settlement_failing_billable_call(**kwargs):
|
||||
from app.services.billable_calls import BillingSettlementError
|
||||
|
||||
_CALL_LOG.append(kwargs)
|
||||
yield SimpleNamespace()
|
||||
raise BillingSettlementError(
|
||||
usage_type=kwargs.get("usage_type", "?"),
|
||||
user_id=kwargs["user_id"],
|
||||
cause=RuntimeError("finalize failed"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_call_log():
|
||||
_CALL_LOG.clear()
|
||||
yield
|
||||
_CALL_LOG.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
|
||||
"""Happy path: free billing tier still wraps the graph call so the
|
||||
audit row is recorded. Verifies kwargs threading."""
|
||||
from app.config import config as app_config
|
||||
from app.db import PodcastStatus
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
|
||||
podcast = _make_podcast(podcast_id=7, thread_id=99)
|
||||
session = _FakeSession(podcast)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks,
|
||||
"get_celery_session_maker",
|
||||
lambda: _FakeSessionMaker(session),
|
||||
)
|
||||
|
||||
user_id = uuid4()
|
||||
|
||||
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||
assert search_space_id == 555
|
||||
assert thread_id == 99
|
||||
return user_id, "free", "openrouter/some-free-model"
|
||||
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||
)
|
||||
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
|
||||
|
||||
async def _fake_graph_invoke(state, config):
|
||||
return {
|
||||
"podcast_transcript": [
|
||||
SimpleNamespace(speaker_id=0, dialog="Hi"),
|
||||
SimpleNamespace(speaker_id=1, dialog="Hello"),
|
||||
],
|
||||
"final_podcast_file_path": "/tmp/podcast.wav",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||
|
||||
result = await podcast_tasks._generate_content_podcast(
|
||||
podcast_id=7,
|
||||
source_content="hello world",
|
||||
search_space_id=555,
|
||||
user_prompt="make it short",
|
||||
)
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert result["podcast_id"] == 7
|
||||
assert podcast.status == PodcastStatus.READY
|
||||
assert podcast.file_location == "/tmp/podcast.wav"
|
||||
|
||||
assert len(_CALL_LOG) == 1
|
||||
call = _CALL_LOG[0]
|
||||
assert call["user_id"] == user_id
|
||||
assert call["search_space_id"] == 555
|
||||
assert call["billing_tier"] == "free"
|
||||
assert call["base_model"] == "openrouter/some-free-model"
|
||||
assert call["usage_type"] == "podcast_generation"
|
||||
assert (
|
||||
call["quota_reserve_micros_override"]
|
||||
== app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
|
||||
)
|
||||
# Background artifact audit rows intentionally omit the TokenUsage.thread_id
|
||||
# FK to avoid coupling Celery audit commits to an active chat transaction.
|
||||
assert "thread_id" not in call
|
||||
assert call["call_details"] == {
|
||||
"podcast_id": 7,
|
||||
"title": "Test Podcast",
|
||||
"thread_id": 99,
|
||||
}
|
||||
assert callable(call["billable_session_factory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
|
||||
"""Premium resolution flows through to ``billable_call`` so the
|
||||
reserve/finalize path triggers."""
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
|
||||
podcast = _make_podcast()
|
||||
session = _FakeSession(podcast)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks,
|
||||
"get_celery_session_maker",
|
||||
lambda: _FakeSessionMaker(session),
|
||||
)
|
||||
|
||||
user_id = uuid4()
|
||||
|
||||
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||
return user_id, "premium", "gpt-5.4"
|
||||
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||
)
|
||||
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
|
||||
|
||||
async def _fake_graph_invoke(state, config):
|
||||
return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
|
||||
|
||||
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||
|
||||
await podcast_tasks._generate_content_podcast(
|
||||
podcast_id=7,
|
||||
source_content="hi",
|
||||
search_space_id=555,
|
||||
user_prompt=None,
|
||||
)
|
||||
|
||||
assert _CALL_LOG[0]["billing_tier"] == "premium"
|
||||
assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
|
||||
"""When ``billable_call`` denies the reservation, the graph never
|
||||
runs and the podcast row flips to FAILED with the documented reason
|
||||
code."""
|
||||
from app.db import PodcastStatus
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
|
||||
podcast = _make_podcast(podcast_id=8)
|
||||
session = _FakeSession(podcast)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks,
|
||||
"get_celery_session_maker",
|
||||
lambda: _FakeSessionMaker(session),
|
||||
)
|
||||
|
||||
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||
return uuid4(), "premium", "gpt-5.4"
|
||||
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||
)
|
||||
monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
|
||||
|
||||
graph_invoked = []
|
||||
|
||||
async def _fake_graph_invoke(state, config):
|
||||
graph_invoked.append(True)
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||
|
||||
result = await podcast_tasks._generate_content_podcast(
|
||||
podcast_id=8,
|
||||
source_content="hi",
|
||||
search_space_id=555,
|
||||
user_prompt=None,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"status": "failed",
|
||||
"podcast_id": 8,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert graph_invoked == [] # Graph never ran on denied reservation.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
|
||||
from app.db import PodcastStatus
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
|
||||
podcast = _make_podcast(podcast_id=10)
|
||||
session = _FakeSession(podcast)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks,
|
||||
"get_celery_session_maker",
|
||||
lambda: _FakeSessionMaker(session),
|
||||
)
|
||||
|
||||
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||
return uuid4(), "premium", "gpt-5.4"
|
||||
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "billable_call", _settlement_failing_billable_call
|
||||
)
|
||||
|
||||
async def _fake_graph_invoke(state, config):
|
||||
return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
|
||||
|
||||
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||
|
||||
result = await podcast_tasks._generate_content_podcast(
|
||||
podcast_id=10,
|
||||
source_content="hi",
|
||||
search_space_id=555,
|
||||
user_prompt=None,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"status": "failed",
|
||||
"podcast_id": 10,
|
||||
"reason": "billing_settlement_failed",
|
||||
}
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
|
||||
"""If the resolver raises (e.g. search-space deleted), the task fails
|
||||
cleanly without invoking the graph."""
|
||||
from app.db import PodcastStatus
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
|
||||
podcast = _make_podcast(podcast_id=9)
|
||||
session = _FakeSession(podcast)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks,
|
||||
"get_celery_session_maker",
|
||||
lambda: _FakeSessionMaker(session),
|
||||
)
|
||||
|
||||
async def _failing_resolver(sess, search_space_id, *, thread_id=None):
|
||||
raise ValueError("Search space 555 not found")
|
||||
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
|
||||
)
|
||||
|
||||
graph_invoked = []
|
||||
|
||||
async def _fake_graph_invoke(state, config):
|
||||
graph_invoked.append(True)
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||
|
||||
result = await podcast_tasks._generate_content_podcast(
|
||||
podcast_id=9,
|
||||
source_content="hi",
|
||||
search_space_id=555,
|
||||
user_prompt=None,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"status": "failed",
|
||||
"podcast_id": 9,
|
||||
"reason": "billing_resolution_failed",
|
||||
}
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert graph_invoked == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue