mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
test(podcasts): relocate stateful tests to integration
Move the lifecycle service, Celery task bodies, and mark_failed coverage out of DB-faking unit tests and into integration tests against a real Postgres, faking only true externals (broker, object store, TTS, ffmpeg, billing, LLM). Add HTTP slices for cancel, voices, scoping, and public-chat streaming. The unit tier is now fake-free pure logic with no session doubles.
This commit is contained in:
parent
8f38737ad9
commit
c84525897b
17 changed files with 985 additions and 465 deletions
319
surfsense_backend/tests/integration/podcasts/conftest.py
Normal file
319
surfsense_backend/tests/integration/podcasts/conftest.py
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
"""Podcast API + task integration fixtures.
|
||||
|
||||
The app's DB session and current-user dependencies ride the test's transactional
|
||||
`db_session`, so seeded rows and rows touched through the endpoints (or the task
|
||||
bodies) share one transaction that rolls back per test. Only true externals are
|
||||
faked: the Celery broker (`*_task.delay`) is captured instead of dispatched, the
|
||||
object store is a tiny in-memory backend, the Celery tasks' own session maker is
|
||||
bound to the test transaction, and — for the render task — the TTS provider and
|
||||
the FFmpeg merge are stubbed. `TTS_SERVICE` is pinned so the deterministic brief
|
||||
proposal can resolve voices.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.app import app, limiter
|
||||
from app.config import config as app_config
|
||||
from app.db import SearchSpace, User, get_async_session
|
||||
from app.routes.search_spaces_routes import create_default_roles_and_membership
|
||||
from app.podcasts.persistence import Podcast, PodcastStatus
|
||||
from app.podcasts.schemas import (
|
||||
DurationTarget,
|
||||
PodcastSpec,
|
||||
PodcastStyle,
|
||||
SpeakerRole,
|
||||
SpeakerSpec,
|
||||
Transcript,
|
||||
TranscriptTurn,
|
||||
)
|
||||
from app.podcasts.service import PodcastService
|
||||
from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech
|
||||
from app.users import current_active_user
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
limiter.enabled = False
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(
|
||||
db_session: AsyncSession,
|
||||
db_user: User,
|
||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
async def override_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
async def override_user() -> User:
|
||||
return db_user
|
||||
|
||||
previous_overrides = app.dependency_overrides.copy()
|
||||
app.dependency_overrides[get_async_session] = override_session
|
||||
app.dependency_overrides[current_active_user] = override_user
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://test",
|
||||
timeout=30.0,
|
||||
follow_redirects=False,
|
||||
) as test_client:
|
||||
yield test_client
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
app.dependency_overrides.update(previous_overrides)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def tts_service(monkeypatch) -> str:
|
||||
"""Pin a provider with language-agnostic voices so brief proposal resolves."""
|
||||
service = "openai/tts-1"
|
||||
monkeypatch.setattr(app_config, "TTS_SERVICE", service)
|
||||
return service
|
||||
|
||||
|
||||
class CapturedTasks:
|
||||
"""Records the args each podcast Celery task was enqueued with."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.draft: list[tuple] = []
|
||||
self.render: list[tuple] = []
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def captured_tasks(monkeypatch) -> CapturedTasks:
|
||||
"""Capture `*_task.delay` instead of hitting the broker (a boundary)."""
|
||||
captured = CapturedTasks()
|
||||
from app.podcasts.tasks import draft_transcript_task, render_audio_task
|
||||
|
||||
monkeypatch.setattr(
|
||||
draft_transcript_task, "delay", lambda *a, **k: captured.draft.append((a, k))
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
render_audio_task, "delay", lambda *a, **k: captured.render.append((a, k))
|
||||
)
|
||||
return captured
|
||||
|
||||
|
||||
class FakeStorageBackend:
|
||||
"""In-memory object store standing in for the real audio backend."""
|
||||
|
||||
backend_name = "memory"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.objects: dict[str, bytes] = {}
|
||||
self.deleted: list[str] = []
|
||||
|
||||
async def put(self, key: str, data: bytes, content_type: str | None = None) -> None:
|
||||
self.objects[key] = data
|
||||
|
||||
async def open_stream(self, key: str) -> AsyncIterator[bytes]:
|
||||
yield self.objects.get(key, b"audio-bytes")
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
self.deleted.append(key)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_storage(monkeypatch) -> FakeStorageBackend:
|
||||
"""Route audio storage to an in-memory backend for the stream routes."""
|
||||
backend = FakeStorageBackend()
|
||||
monkeypatch.setattr(
|
||||
"app.podcasts.storage.get_storage_backend", lambda: backend
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.file_storage.factory.get_storage_backend", lambda: backend
|
||||
)
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bind_task_session(db_session: AsyncSession, monkeypatch) -> AsyncSession:
|
||||
"""Bind the Celery tasks' own session maker to the test transaction.
|
||||
|
||||
Task bodies open ``get_celery_session_maker()()`` rather than receiving a
|
||||
session, so this hands them the test's session without closing it on exit; a
|
||||
task's ``commit()`` then releases a savepoint and the per-test rollback still
|
||||
cleans up.
|
||||
"""
|
||||
|
||||
def _make_session():
|
||||
@contextlib.asynccontextmanager
|
||||
async def _ctx() -> AsyncIterator[AsyncSession]:
|
||||
yield db_session
|
||||
|
||||
return _ctx()
|
||||
|
||||
for module in (
|
||||
"app.podcasts.tasks.draft",
|
||||
"app.podcasts.tasks.render",
|
||||
"app.podcasts.tasks.runtime",
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
f"{module}.get_celery_session_maker", lambda: _make_session
|
||||
)
|
||||
return db_session
|
||||
|
||||
|
||||
class FakeTextToSpeech(TextToSpeech):
|
||||
"""In-memory TTS provider: every segment yields fixed bytes (the boundary)."""
|
||||
|
||||
@property
|
||||
def container(self) -> str:
|
||||
return "mp3"
|
||||
|
||||
async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio:
|
||||
return SynthesizedAudio(data=b"segment-audio", container="mp3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_tts(monkeypatch) -> FakeTextToSpeech:
|
||||
"""Stand in for the configured TTS provider in the render task."""
|
||||
provider = FakeTextToSpeech()
|
||||
monkeypatch.setattr(
|
||||
"app.podcasts.tasks.render.get_text_to_speech", lambda: provider
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_merge(monkeypatch) -> None:
|
||||
"""Stub the FFmpeg merge (an external binary) to emit a fixed MP3."""
|
||||
|
||||
async def _merge(segment_paths: list[Path], output_path: Path) -> None:
|
||||
output_path.write_bytes(b"merged-audio")
|
||||
|
||||
monkeypatch.setattr("app.podcasts.rendering.renderer.concat_to_mp3", _merge)
|
||||
|
||||
|
||||
def build_spec(
|
||||
*,
|
||||
language: str = "en",
|
||||
voice_ids: tuple[str, str] = ("openai:alloy", "openai:nova"),
|
||||
) -> PodcastSpec:
|
||||
"""A valid two-speaker brief; tests override only what they assert on."""
|
||||
return PodcastSpec(
|
||||
language=language,
|
||||
style=PodcastStyle.CONVERSATIONAL,
|
||||
speakers=[
|
||||
SpeakerSpec(slot=0, name="Host", role=SpeakerRole.HOST, voice_id=voice_ids[0]),
|
||||
SpeakerSpec(slot=1, name="Guest", role=SpeakerRole.GUEST, voice_id=voice_ids[1]),
|
||||
],
|
||||
duration=DurationTarget(min_minutes=10, max_minutes=20),
|
||||
)
|
||||
|
||||
|
||||
def build_transcript() -> Transcript:
|
||||
return Transcript(
|
||||
turns=[
|
||||
TranscriptTurn(speaker=0, text="Welcome to the show."),
|
||||
TranscriptTurn(speaker=1, text="Glad to be here."),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_podcast(db_session: AsyncSession):
|
||||
"""Create a podcast advanced to a target lifecycle state via the service.
|
||||
|
||||
Setup runs through the same public service the API uses, on the test's
|
||||
session, so the endpoint under test reads a realistically-built row.
|
||||
"""
|
||||
|
||||
_LADDER = [
|
||||
PodcastStatus.AWAITING_BRIEF,
|
||||
PodcastStatus.DRAFTING,
|
||||
PodcastStatus.AWAITING_REVIEW,
|
||||
PodcastStatus.RENDERING,
|
||||
PodcastStatus.READY,
|
||||
]
|
||||
|
||||
async def _make(
|
||||
*,
|
||||
search_space_id: int,
|
||||
status: PodcastStatus = PodcastStatus.AWAITING_BRIEF,
|
||||
title: str = "Test Podcast",
|
||||
thread_id: int | None = None,
|
||||
) -> Podcast:
|
||||
service = PodcastService(db_session)
|
||||
podcast = await service.create(
|
||||
title=title, search_space_id=search_space_id, thread_id=thread_id
|
||||
)
|
||||
if status is PodcastStatus.PENDING:
|
||||
await db_session.flush()
|
||||
return podcast
|
||||
|
||||
targets = _LADDER[: _LADDER.index(status) + 1]
|
||||
for target in targets:
|
||||
if target is PodcastStatus.AWAITING_BRIEF:
|
||||
await service.attach_brief(podcast, build_spec())
|
||||
elif target is PodcastStatus.DRAFTING:
|
||||
await service.begin_drafting(podcast)
|
||||
elif target is PodcastStatus.AWAITING_REVIEW:
|
||||
await service.attach_transcript(podcast, build_transcript())
|
||||
elif target is PodcastStatus.RENDERING:
|
||||
await service.approve(podcast)
|
||||
elif target is PodcastStatus.READY:
|
||||
await service.attach_audio(
|
||||
podcast,
|
||||
storage_backend="memory",
|
||||
storage_key="podcasts/audio.mp3",
|
||||
duration_seconds=123,
|
||||
)
|
||||
await db_session.flush()
|
||||
return podcast
|
||||
|
||||
return _make
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def act_as():
|
||||
"""Switch the authenticated user for subsequent requests on ``client``.
|
||||
|
||||
The ``client`` fixture installs db_user and restores the prior overrides on
|
||||
teardown, so re-pointing the auth dependency here is undone per test.
|
||||
"""
|
||||
|
||||
def _act(user: User) -> None:
|
||||
app.dependency_overrides[current_active_user] = lambda: user
|
||||
|
||||
return _act
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_other_user(db_session: AsyncSession) -> User:
|
||||
"""A second user who is not a member of ``db_search_space``."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="stranger@surfsense.net",
|
||||
hashed_password="hashed",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def foreign_podcast(
|
||||
db_session: AsyncSession, db_other_user: User, make_podcast
|
||||
) -> Podcast:
|
||||
"""A podcast in a space owned by the other user, invisible to db_user."""
|
||||
space = SearchSpace(name="Stranger Space", user_id=db_other_user.id)
|
||||
db_session.add(space)
|
||||
await db_session.flush()
|
||||
await create_default_roles_and_membership(db_session, space.id, db_other_user.id)
|
||||
await db_session.flush()
|
||||
return await make_podcast(search_space_id=space.id, title="Foreign")
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
"""The brief review gate: edit the spec, then approve to start drafting.
|
||||
|
||||
Covers what the user can do while ``awaiting_brief`` — edit the brief under
|
||||
optimistic concurrency and approve it — and the HTTP status codes the service's
|
||||
guards map to when an edit races or comes too late.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
BASE = "/api/v1/podcasts"
|
||||
|
||||
|
||||
async def _create(client, search_space_id: int) -> dict:
|
||||
resp = await client.post(
|
||||
BASE,
|
||||
json={
|
||||
"title": "Episode",
|
||||
"search_space_id": search_space_id,
|
||||
"source_content": "Source content.",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def test_approve_brief_starts_drafting_and_enqueues_draft(
|
||||
client, db_search_space, captured_tasks
|
||||
):
|
||||
podcast = await _create(client, db_search_space.id)
|
||||
|
||||
resp = await client.post(f"{BASE}/{podcast['id']}/brief/approve")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "drafting"
|
||||
assert captured_tasks.draft == [((podcast["id"], db_search_space.id), {})]
|
||||
assert captured_tasks.render == []
|
||||
|
||||
|
||||
async def test_update_spec_bumps_version_and_persists(client, db_search_space):
|
||||
podcast = await _create(client, db_search_space.id)
|
||||
spec = podcast["spec"]
|
||||
spec["focus"] = "A sharper angle"
|
||||
|
||||
resp = await client.patch(
|
||||
f"{BASE}/{podcast['id']}/spec",
|
||||
json={"spec": spec, "expected_version": podcast["spec_version"]},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["spec_version"] == podcast["spec_version"] + 1
|
||||
assert body["spec"]["focus"] == "A sharper angle"
|
||||
assert body["status"] == "awaiting_brief"
|
||||
|
||||
|
||||
async def test_update_spec_with_stale_version_conflicts(client, db_search_space):
|
||||
podcast = await _create(client, db_search_space.id)
|
||||
|
||||
resp = await client.patch(
|
||||
f"{BASE}/{podcast['id']}/spec",
|
||||
json={"spec": podcast["spec"], "expected_version": 999},
|
||||
)
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
async def test_update_spec_after_approval_is_rejected(client, db_search_space):
|
||||
podcast = await _create(client, db_search_space.id)
|
||||
await client.post(f"{BASE}/{podcast['id']}/brief/approve")
|
||||
|
||||
resp = await client.patch(
|
||||
f"{BASE}/{podcast['id']}/spec",
|
||||
json={"spec": podcast["spec"], "expected_version": podcast["spec_version"]},
|
||||
)
|
||||
|
||||
assert resp.status_code == 409
|
||||
39
surfsense_backend/tests/integration/podcasts/test_cancel.py
Normal file
39
surfsense_backend/tests/integration/podcasts/test_cancel.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Cancelling a podcast: allowed while in flight, refused once terminal.
|
||||
|
||||
Cancellation is a user escape hatch from any non-terminal state; a podcast that
|
||||
has already finished (READY) has no exit, so the disallowed transition surfaces
|
||||
as 409.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import PodcastStatus
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
BASE = "/api/v1/podcasts"
|
||||
|
||||
|
||||
async def test_cancel_from_a_live_state_succeeds(
|
||||
client, db_search_space, make_podcast
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_BRIEF
|
||||
)
|
||||
|
||||
resp = await client.post(f"{BASE}/{podcast.id}/cancel")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "cancelled"
|
||||
|
||||
|
||||
async def test_cancel_from_a_terminal_state_conflicts(
|
||||
client, db_search_space, make_podcast
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.READY
|
||||
)
|
||||
|
||||
resp = await client.post(f"{BASE}/{podcast.id}/cancel")
|
||||
|
||||
assert resp.status_code == 409
|
||||
51
surfsense_backend/tests/integration/podcasts/test_create.py
Normal file
51
surfsense_backend/tests/integration/podcasts/test_create.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""Creating a podcast proposes a brief and opens the review gate.
|
||||
|
||||
Driven through the real POST endpoint (auth + DB on one transaction): the row is
|
||||
created, a brief is proposed inline from defaults, and the podcast lands in
|
||||
``awaiting_brief`` with a complete spec and nothing generated yet.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
BASE = "/api/v1/podcasts"
|
||||
|
||||
|
||||
async def test_create_proposes_brief_and_opens_gate(client, db_search_space):
|
||||
resp = await client.post(
|
||||
BASE,
|
||||
json={
|
||||
"title": "My Episode",
|
||||
"search_space_id": db_search_space.id,
|
||||
"source_content": "A long piece of source content about a topic.",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
body = resp.json()
|
||||
assert body["title"] == "My Episode"
|
||||
assert body["status"] == "awaiting_brief"
|
||||
assert body["spec_version"] == 1
|
||||
assert body["spec"] is not None
|
||||
assert body["spec"]["language"] == "en"
|
||||
assert len(body["spec"]["speakers"]) == 2
|
||||
assert body["transcript"] is None
|
||||
assert body["has_audio"] is False
|
||||
|
||||
|
||||
async def test_create_honors_requested_speaker_count(client, db_search_space):
|
||||
resp = await client.post(
|
||||
BASE,
|
||||
json={
|
||||
"title": "Solo",
|
||||
"search_space_id": db_search_space.id,
|
||||
"source_content": "Content.",
|
||||
"speaker_count": 3,
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
assert len(resp.json()["spec"]["speakers"]) == 3
|
||||
115
surfsense_backend/tests/integration/podcasts/test_draft_task.py
Normal file
115
surfsense_backend/tests/integration/podcasts/test_draft_task.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""The transcript-drafting task against a real database.
|
||||
|
||||
Drafting is the expensive LLM step, so it runs under ``billable_call``. The
|
||||
behavior that protects users' money: when billing succeeds, a drafted transcript
|
||||
opens the review gate (DRAFTING -> AWAITING_REVIEW); when billing denies or
|
||||
settlement fails, the podcast ends FAILED with no transcript left behind. The DB,
|
||||
service, and transcript persistence run for real; only the true externals are
|
||||
faked — billing (the metering boundary) and the generation graph (the LLM).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import PodcastStatus
|
||||
from app.podcasts.service import read_transcript
|
||||
from app.podcasts.tasks import draft
|
||||
from app.services.billable_calls import (
|
||||
BillingSettlementError,
|
||||
QuotaInsufficientError,
|
||||
)
|
||||
|
||||
from .conftest import build_transcript
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _wire_billing(monkeypatch, *, billable_call, transcript=None) -> None:
|
||||
"""Replace the billing + LLM externals the draft body reaches for."""
|
||||
|
||||
async def _resolver(_session, _search_space_id, *, thread_id=None):
|
||||
return uuid4(), "free", "openrouter/model"
|
||||
|
||||
async def _ainvoke(_state, config=None):
|
||||
return {"transcript": transcript}
|
||||
|
||||
monkeypatch.setattr(draft, "_resolve_agent_billing_for_search_space", _resolver)
|
||||
monkeypatch.setattr(draft, "billable_call", billable_call)
|
||||
monkeypatch.setattr(draft, "transcript_graph", SimpleNamespace(ainvoke=_ainvoke))
|
||||
|
||||
|
||||
async def test_successful_billing_opens_review_gate_with_transcript(
|
||||
monkeypatch, db_search_space, make_podcast, bind_task_session
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _ok(**_kwargs):
|
||||
yield SimpleNamespace()
|
||||
|
||||
_wire_billing(monkeypatch, billable_call=_ok, transcript=build_transcript())
|
||||
|
||||
result = await draft._draft_transcript(podcast.id, db_search_space.id)
|
||||
|
||||
assert result["status"] == "awaiting_review"
|
||||
assert podcast.status == PodcastStatus.AWAITING_REVIEW
|
||||
assert read_transcript(podcast) is not None
|
||||
|
||||
|
||||
async def test_quota_denial_fails_the_podcast_without_a_transcript(
|
||||
monkeypatch, db_search_space, make_podcast, bind_task_session
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _deny(**_kwargs):
|
||||
raise QuotaInsufficientError(
|
||||
usage_type="podcast_generation",
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield # pragma: no cover - unreachable, satisfies the CM protocol
|
||||
|
||||
_wire_billing(monkeypatch, billable_call=_deny)
|
||||
|
||||
result = await draft._draft_transcript(podcast.id, db_search_space.id)
|
||||
|
||||
assert result["reason"] == "quota"
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert read_transcript(podcast) is None
|
||||
|
||||
|
||||
async def test_billing_settlement_failure_fails_the_podcast(
|
||||
monkeypatch, db_search_space, make_podcast, bind_task_session
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _settlement_fails(**_kwargs):
|
||||
yield SimpleNamespace()
|
||||
raise BillingSettlementError(
|
||||
usage_type="podcast_generation",
|
||||
user_id=uuid4(),
|
||||
cause=RuntimeError("finalize failed"),
|
||||
)
|
||||
|
||||
_wire_billing(
|
||||
monkeypatch, billable_call=_settlement_fails, transcript=build_transcript()
|
||||
)
|
||||
|
||||
result = await draft._draft_transcript(podcast.id, db_search_space.id)
|
||||
|
||||
assert result["reason"] == "billing"
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
"""Public (unauthenticated) podcast streaming from a chat snapshot.
|
||||
|
||||
A shared chat snapshot carries each podcast's stored-audio key; the public route
|
||||
streams those bytes from the object store via ``share_token`` with no auth. A
|
||||
podcast that isn't in the snapshot is a 404.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import NewChatThread, PublicChatSnapshot, User
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def _snapshot(db_session, *, search_space_id, user: User, token: str, podcasts):
|
||||
thread = NewChatThread(
|
||||
title="Shared", search_space_id=search_space_id, created_by_id=user.id
|
||||
)
|
||||
db_session.add(thread)
|
||||
await db_session.flush()
|
||||
snapshot = PublicChatSnapshot(
|
||||
thread_id=thread.id,
|
||||
share_token=token,
|
||||
content_hash=f"hash-{token}",
|
||||
message_ids=[],
|
||||
snapshot_data={"podcasts": podcasts},
|
||||
)
|
||||
db_session.add(snapshot)
|
||||
await db_session.flush()
|
||||
|
||||
|
||||
async def test_public_stream_serves_audio_via_storage_key(
|
||||
client, db_session, db_search_space, db_user, fake_storage
|
||||
):
|
||||
await _snapshot(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
user=db_user,
|
||||
token="tok-audio",
|
||||
podcasts=[{"original_id": 555, "storage_key": "podcasts/x.mp3"}],
|
||||
)
|
||||
fake_storage.objects["podcasts/x.mp3"] = b"public-audio"
|
||||
|
||||
resp = await client.get("/api/v1/public/tok-audio/podcasts/555/stream")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"] == "audio/mpeg"
|
||||
assert resp.content == b"public-audio"
|
||||
|
||||
|
||||
async def test_public_stream_404_when_podcast_absent_from_snapshot(
|
||||
client, db_session, db_search_space, db_user
|
||||
):
|
||||
await _snapshot(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
user=db_user,
|
||||
token="tok-empty",
|
||||
podcasts=[],
|
||||
)
|
||||
|
||||
resp = await client.get("/api/v1/public/tok-empty/podcasts/999/stream")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
"""The audio-rendering task against a real database.
|
||||
|
||||
From RENDERING, the task synthesises and merges the approved transcript, stores
|
||||
the bytes, and marks the podcast READY with the storage location recorded. The
|
||||
DB, service, renderer orchestration, and storage wrapper run for real; the true
|
||||
externals are faked — the TTS provider, the FFmpeg merge, and the object store.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import PodcastStatus
|
||||
from app.podcasts.tasks import render
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_render_marks_ready_and_stores_audio(
|
||||
db_search_space, make_podcast, bind_task_session, fake_tts, fake_merge, fake_storage
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.RENDERING
|
||||
)
|
||||
|
||||
result = await render._render_audio(podcast.id)
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert podcast.status == PodcastStatus.READY
|
||||
assert podcast.storage_backend == "memory"
|
||||
assert podcast.storage_key
|
||||
assert fake_storage.objects[podcast.storage_key] == b"merged-audio"
|
||||
53
surfsense_backend/tests/integration/podcasts/test_scoping.py
Normal file
53
surfsense_backend/tests/integration/podcasts/test_scoping.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Podcasts are scoped to search-space membership.
|
||||
|
||||
A user can only create or read podcasts in spaces they belong to, and an
|
||||
unscoped listing returns only the caller's own podcasts — never another
|
||||
member's.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
BASE = "/api/v1/podcasts"
|
||||
|
||||
|
||||
async def test_reading_a_podcast_in_a_nonmember_space_is_forbidden(
|
||||
client, db_search_space, make_podcast, act_as, db_other_user
|
||||
):
|
||||
podcast = await make_podcast(search_space_id=db_search_space.id)
|
||||
act_as(db_other_user)
|
||||
|
||||
resp = await client.get(f"{BASE}/{podcast.id}")
|
||||
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
async def test_creating_in_a_nonmember_space_is_forbidden(
|
||||
client, db_search_space, act_as, db_other_user
|
||||
):
|
||||
act_as(db_other_user)
|
||||
|
||||
resp = await client.post(
|
||||
BASE,
|
||||
json={
|
||||
"title": "X",
|
||||
"search_space_id": db_search_space.id,
|
||||
"source_content": "content",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
async def test_listing_returns_only_the_callers_podcasts(
|
||||
client, db_search_space, make_podcast, foreign_podcast
|
||||
):
|
||||
mine = await make_podcast(search_space_id=db_search_space.id, title="Mine")
|
||||
|
||||
resp = await client.get(BASE)
|
||||
|
||||
assert resp.status_code == 200
|
||||
ids = {p["id"] for p in resp.json()}
|
||||
assert mine.id in ids
|
||||
assert foreign_podcast.id not in ids
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""Streaming a podcast's rendered audio over HTTP.
|
||||
|
||||
A ready podcast streams its bytes from the storage backend; a podcast with no
|
||||
stored audio returns 404. Storage is an in-memory backend (the object store is a
|
||||
system boundary).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import PodcastStatus
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
BASE = "/api/v1/podcasts"
|
||||
|
||||
|
||||
async def test_stream_serves_stored_audio(
|
||||
client, db_search_space, make_podcast, fake_storage
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.READY
|
||||
)
|
||||
fake_storage.objects["podcasts/audio.mp3"] = b"the-audio"
|
||||
|
||||
resp = await client.get(f"{BASE}/{podcast.id}/stream")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"] == "audio/mpeg"
|
||||
assert resp.content == b"the-audio"
|
||||
|
||||
|
||||
async def test_stream_404_when_no_audio(client, db_search_space, make_podcast):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_REVIEW
|
||||
)
|
||||
|
||||
resp = await client.get(f"{BASE}/{podcast.id}/stream")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
"""The task failure safety net (``mark_failed``) against a real database.
|
||||
|
||||
When a task body raises, ``mark_failed`` records the reason on the row. Its
|
||||
contract has two halves worth securing: a still-running podcast moves to FAILED
|
||||
with the reason, while one that already reached a terminal state is left exactly
|
||||
as it was rather than forced. A missing row is a no-op, never a crash.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import PodcastStatus
|
||||
from app.podcasts.tasks import runtime
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_marking_failed_records_the_reason_on_a_running_podcast(
|
||||
db_search_space, make_podcast, bind_task_session
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.DRAFTING
|
||||
)
|
||||
|
||||
await runtime.mark_failed(podcast.id, "tts provider unavailable")
|
||||
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert podcast.error == "tts provider unavailable"
|
||||
|
||||
|
||||
async def test_marking_failed_leaves_an_already_terminal_podcast_untouched(
|
||||
db_search_space, make_podcast, bind_task_session
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.READY
|
||||
)
|
||||
|
||||
await runtime.mark_failed(podcast.id, "too late")
|
||||
|
||||
assert podcast.status == PodcastStatus.READY
|
||||
|
||||
|
||||
async def test_marking_a_missing_podcast_failed_is_a_no_op(bind_task_session):
|
||||
await runtime.mark_failed(987654321, "gone") # must not raise
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
"""The transcript go/no-go gate: approve to render, or regenerate to redraft.
|
||||
|
||||
From ``awaiting_review`` the user either approves (start rendering) or regenerates
|
||||
(redraft). These pin the resulting state, the Celery task each enqueues, and the
|
||||
HTTP codes for acting from the wrong state (409) or without a transcript (422).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import Podcast, PodcastStatus
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
BASE = "/api/v1/podcasts"
|
||||
|
||||
|
||||
async def test_approve_transcript_starts_rendering_and_enqueues_render(
|
||||
client, db_search_space, make_podcast, captured_tasks
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_REVIEW
|
||||
)
|
||||
|
||||
resp = await client.post(f"{BASE}/{podcast.id}/transcript/approve")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "rendering"
|
||||
assert captured_tasks.render == [((podcast.id,), {})]
|
||||
assert captured_tasks.draft == []
|
||||
|
||||
|
||||
async def test_regenerate_returns_to_drafting_and_enqueues_draft(
|
||||
client, db_search_space, make_podcast, captured_tasks
|
||||
):
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.AWAITING_REVIEW
|
||||
)
|
||||
|
||||
resp = await client.post(f"{BASE}/{podcast.id}/transcript/regenerate")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "drafting"
|
||||
assert captured_tasks.draft == [((podcast.id, db_search_space.id), {})]
|
||||
assert captured_tasks.render == []
|
||||
|
||||
|
||||
async def test_approve_transcript_from_terminal_state_is_rejected(
|
||||
client, db_search_space, make_podcast, captured_tasks
|
||||
):
|
||||
# A ready podcast still has its transcript, so the precondition passes and
|
||||
# the disallowed terminal->rendering transition is what surfaces (409).
|
||||
podcast = await make_podcast(
|
||||
search_space_id=db_search_space.id, status=PodcastStatus.READY
|
||||
)
|
||||
|
||||
resp = await client.post(f"{BASE}/{podcast.id}/transcript/approve")
|
||||
|
||||
assert resp.status_code == 409
|
||||
assert captured_tasks.render == []
|
||||
|
||||
|
||||
async def test_approve_without_transcript_is_unprocessable(
|
||||
client, db_session, db_search_space, captured_tasks
|
||||
):
|
||||
# An anomalous awaiting_review row with no transcript exercises the route's
|
||||
# precondition->422 mapping (the service refuses to render without one).
|
||||
podcast = Podcast(
|
||||
title="No transcript",
|
||||
search_space_id=db_search_space.id,
|
||||
status=PodcastStatus.AWAITING_REVIEW,
|
||||
spec_version=1,
|
||||
)
|
||||
db_session.add(podcast)
|
||||
await db_session.flush()
|
||||
|
||||
resp = await client.post(f"{BASE}/{podcast.id}/transcript/approve")
|
||||
|
||||
assert resp.status_code == 422
|
||||
assert captured_tasks.render == []
|
||||
31
surfsense_backend/tests/integration/podcasts/test_voices.py
Normal file
31
surfsense_backend/tests/integration/podcasts/test_voices.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""GET /podcasts/voices: the active provider's catalog, or 503 if unconfigured.
|
||||
|
||||
The brief UI needs the voices the configured TTS provider offers; with no
|
||||
provider configured there is nothing to choose from, which is a 503 rather than
|
||||
an empty list.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
BASE = "/api/v1/podcasts"
|
||||
|
||||
|
||||
async def test_voices_returns_the_active_providers_catalog(client):
|
||||
resp = await client.get(f"{BASE}/voices")
|
||||
|
||||
assert resp.status_code == 200
|
||||
voices = resp.json()
|
||||
assert voices # openai/tts-1 offers voices
|
||||
assert {"voice_id", "display_name", "language", "gender"} <= voices[0].keys()
|
||||
|
||||
|
||||
async def test_voices_503_when_no_tts_configured(client, monkeypatch):
|
||||
monkeypatch.setattr(app_config, "TTS_SERVICE", "")
|
||||
|
||||
resp = await client.get(f"{BASE}/voices")
|
||||
|
||||
assert resp.status_code == 503
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
"""Shared builders for podcast unit tests.
|
||||
|
||||
These tests exercise the podcast domain through its public interfaces. The only
|
||||
test double is a minimal stand-in for the SQLAlchemy ``AsyncSession`` — a real
|
||||
system boundary — so the service's own repository and state machine run for
|
||||
real. Briefs and transcripts are built with valid factories so each test states
|
||||
just the fields it cares about.
|
||||
These tests exercise pure logic through public interfaces with no test doubles:
|
||||
the brief and transcript factories build valid aggregates so each test states
|
||||
only the fields it cares about. Stateful, persistence-backed paths (the lifecycle
|
||||
service, the Celery task bodies) are covered by the integration suite against a
|
||||
real database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -22,76 +22,6 @@ from app.podcasts.schemas import (
|
|||
)
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
"""A no-op stand-in for ``AsyncSession`` at the persistence boundary.
|
||||
|
||||
The service flushes to assign state within a unit of work; in a unit test
|
||||
there is no database, so ``add``/``flush`` simply do nothing. Behavior is
|
||||
observed through the returned aggregate, never through this double.
|
||||
"""
|
||||
|
||||
def add(self, _obj: object) -> None:
|
||||
return None
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class FakeCeleryDbSession(FakeAsyncSession):
|
||||
"""An async-context session double for Celery task bodies.
|
||||
|
||||
Task bodies open ``get_celery_session_maker()()`` as an async context,
|
||||
``get`` the row, then ``commit``. This holds one preloaded podcast and
|
||||
records whether the body committed, so tests assert on the row's final
|
||||
state — not on the calls made to get there.
|
||||
"""
|
||||
|
||||
def __init__(self, podcast: object | None = None) -> None:
|
||||
self._podcast = podcast
|
||||
self.committed = False
|
||||
|
||||
async def get(self, _model: object, _id: object) -> object | None:
|
||||
return self._podcast
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def __aenter__(self) -> FakeCeleryDbSession:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_exc: object) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_session() -> FakeAsyncSession:
|
||||
return FakeAsyncSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_celery_session():
|
||||
"""Factory for a Celery-style session double holding one podcast."""
|
||||
|
||||
def _make(podcast: object | None = None) -> FakeCeleryDbSession:
|
||||
return FakeCeleryDbSession(podcast)
|
||||
|
||||
return _make
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker_for():
|
||||
"""Build a ``get_celery_session_maker`` replacement bound to one session.
|
||||
|
||||
``get_celery_session_maker()()`` must yield the session, so the replacement
|
||||
is a zero-arg callable returning a maker that returns the session.
|
||||
"""
|
||||
|
||||
def _make(session: object):
|
||||
return lambda: (lambda: session)
|
||||
|
||||
return _make
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_spec():
|
||||
"""Factory for a valid :class:`PodcastSpec`; override only what matters."""
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
"""The API read model the frontend renders from.
|
||||
|
||||
``PodcastDetail.of`` is the contract the detail view and action responses
|
||||
depend on: it exposes the deserialized brief and transcript and a simple
|
||||
``has_audio`` flag the client can't derive from the published Zero columns.
|
||||
These tests drive real podcasts through the service, then assert the read model
|
||||
reflects their state.
|
||||
``PodcastDetail.of`` maps a stored podcast row to the detail view and action
|
||||
responses: it exposes the deserialized brief and transcript and a simple
|
||||
``has_audio`` flag the client can't derive from the published Zero columns. Each
|
||||
test builds a row in one lifecycle shape and asserts the mapping reflects it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -14,28 +13,27 @@ from datetime import UTC, datetime
|
|||
import pytest
|
||||
|
||||
from app.podcasts.api.schemas import PodcastDetail
|
||||
from app.podcasts.persistence import PodcastStatus
|
||||
from app.podcasts.service import PodcastService
|
||||
from app.podcasts.persistence import Podcast, PodcastStatus
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _stamp(podcast):
|
||||
"""Give a transient row the id and created_at a persisted one would have.
|
||||
|
||||
A detail response is only ever built from a saved podcast; without a real
|
||||
database, we stand in the primary key and timestamp the DB would assign.
|
||||
"""
|
||||
def _podcast(*, status: PodcastStatus = PodcastStatus.PENDING, **columns) -> Podcast:
|
||||
"""A persisted-looking row: the id and created_at a saved podcast would carry."""
|
||||
podcast = Podcast(
|
||||
title="Episode",
|
||||
search_space_id=3,
|
||||
status=status,
|
||||
spec_version=1,
|
||||
**columns,
|
||||
)
|
||||
podcast.id = 1
|
||||
podcast.created_at = datetime.now(UTC)
|
||||
return podcast
|
||||
|
||||
|
||||
async def test_a_fresh_podcast_exposes_no_brief_transcript_or_audio(fake_session):
|
||||
service = PodcastService(fake_session)
|
||||
podcast = _stamp(await service.create(title="New", search_space_id=3))
|
||||
|
||||
detail = PodcastDetail.of(podcast)
|
||||
def test_a_fresh_podcast_exposes_no_brief_transcript_or_audio():
|
||||
detail = PodcastDetail.of(_podcast())
|
||||
|
||||
assert detail.status == PodcastStatus.PENDING
|
||||
assert detail.spec is None
|
||||
|
|
@ -43,12 +41,11 @@ async def test_a_fresh_podcast_exposes_no_brief_transcript_or_audio(fake_session
|
|||
assert detail.has_audio is False
|
||||
|
||||
|
||||
async def test_an_awaiting_brief_podcast_exposes_the_deserialized_brief(
|
||||
fake_session, make_spec
|
||||
):
|
||||
service = PodcastService(fake_session)
|
||||
podcast = _stamp(await service.create(title="Brief", search_space_id=3))
|
||||
await service.attach_brief(podcast, make_spec(language="fr"))
|
||||
def test_an_awaiting_brief_podcast_exposes_the_deserialized_brief(make_spec):
|
||||
podcast = _podcast(
|
||||
status=PodcastStatus.AWAITING_BRIEF,
|
||||
spec=make_spec(language="fr").model_dump(mode="json"),
|
||||
)
|
||||
|
||||
detail = PodcastDetail.of(podcast)
|
||||
|
||||
|
|
@ -56,17 +53,14 @@ async def test_an_awaiting_brief_podcast_exposes_the_deserialized_brief(
|
|||
assert detail.spec.language == "fr"
|
||||
|
||||
|
||||
async def test_a_ready_podcast_reports_available_audio(
|
||||
fake_session, make_spec, make_transcript
|
||||
):
|
||||
service = PodcastService(fake_session)
|
||||
podcast = _stamp(await service.create(title="Done", search_space_id=3))
|
||||
await service.attach_brief(podcast, make_spec())
|
||||
await service.begin_drafting(podcast)
|
||||
await service.attach_transcript(podcast, make_transcript())
|
||||
await service.approve(podcast)
|
||||
await service.attach_audio(
|
||||
podcast, storage_backend="local", storage_key="k", duration_seconds=120
|
||||
def test_a_ready_podcast_reports_available_audio(make_spec, make_transcript):
|
||||
podcast = _podcast(
|
||||
status=PodcastStatus.READY,
|
||||
spec=make_spec().model_dump(mode="json"),
|
||||
podcast_transcript=make_transcript().model_dump(mode="json"),
|
||||
storage_backend="local",
|
||||
storage_key="k",
|
||||
duration_seconds=120,
|
||||
)
|
||||
|
||||
detail = PodcastDetail.of(podcast)
|
||||
|
|
|
|||
|
|
@ -1,135 +0,0 @@
|
|||
"""The transcript-drafting task's billing gate.
|
||||
|
||||
Drafting is the expensive LLM step, so it runs under ``billable_call``. The
|
||||
behavior that protects users' money: if billing denies the reservation the
|
||||
podcast must end FAILED with no transcript, and only when billing succeeds does
|
||||
a drafted transcript open the review gate. These tests fake the true
|
||||
boundaries — the database, the billing system, and the generation graph — and
|
||||
assert the podcast's resulting state, never how those boundaries were called.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import Podcast, PodcastStatus
|
||||
from app.podcasts.service import read_transcript
|
||||
from app.podcasts.tasks import draft
|
||||
from app.services.billable_calls import (
|
||||
BillingSettlementError,
|
||||
QuotaInsufficientError,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _drafting_podcast(make_spec) -> Podcast:
|
||||
"""A podcast already at DRAFTING with an approved brief, as the API leaves it."""
|
||||
podcast = Podcast(
|
||||
title="Episode",
|
||||
search_space_id=42,
|
||||
status=PodcastStatus.DRAFTING,
|
||||
spec_version=1,
|
||||
)
|
||||
podcast.id = 1
|
||||
podcast.thread_id = None
|
||||
podcast.spec = make_spec().model_dump(mode="json")
|
||||
podcast.source_content = "Some source material to discuss."
|
||||
return podcast
|
||||
|
||||
|
||||
def _wire_boundaries(monkeypatch, *, session, billable_call, transcript=None):
|
||||
"""Replace every external dependency the task body reaches for."""
|
||||
monkeypatch.setattr(draft, "get_celery_session_maker", lambda: (lambda: session))
|
||||
|
||||
async def _resolver(_session, _search_space_id, *, thread_id=None):
|
||||
return uuid4(), "free", "openrouter/model"
|
||||
|
||||
monkeypatch.setattr(
|
||||
draft, "_resolve_agent_billing_for_search_space", _resolver
|
||||
)
|
||||
monkeypatch.setattr(draft, "billable_call", billable_call)
|
||||
|
||||
async def _ainvoke(_state, config=None):
|
||||
return {"transcript": transcript}
|
||||
|
||||
monkeypatch.setattr(draft, "transcript_graph", SimpleNamespace(ainvoke=_ainvoke))
|
||||
|
||||
|
||||
async def test_successful_billing_opens_the_review_gate_with_a_transcript(
|
||||
monkeypatch, make_celery_session, make_spec, make_transcript
|
||||
):
|
||||
podcast = _drafting_podcast(make_spec)
|
||||
session = make_celery_session(podcast)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _ok(**_kwargs):
|
||||
yield SimpleNamespace()
|
||||
|
||||
_wire_boundaries(
|
||||
monkeypatch, session=session, billable_call=_ok, transcript=make_transcript()
|
||||
)
|
||||
|
||||
result = await draft._draft_transcript(podcast_id=1, search_space_id=42)
|
||||
|
||||
assert podcast.status == PodcastStatus.AWAITING_REVIEW
|
||||
assert read_transcript(podcast) is not None
|
||||
assert result["status"] == "awaiting_review"
|
||||
|
||||
|
||||
async def test_quota_denial_fails_the_podcast_without_a_transcript(
|
||||
monkeypatch, make_celery_session, make_spec
|
||||
):
|
||||
"""A denied reservation must not leave a half-drafted, billable mess."""
|
||||
podcast = _drafting_podcast(make_spec)
|
||||
session = make_celery_session(podcast)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _deny(**_kwargs):
|
||||
raise QuotaInsufficientError(
|
||||
usage_type="podcast_generation",
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield # pragma: no cover - unreachable, satisfies the CM protocol
|
||||
|
||||
_wire_boundaries(monkeypatch, session=session, billable_call=_deny)
|
||||
|
||||
result = await draft._draft_transcript(podcast_id=1, search_space_id=42)
|
||||
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert read_transcript(podcast) is None
|
||||
assert result["reason"] == "quota"
|
||||
|
||||
|
||||
async def test_billing_settlement_failure_fails_the_podcast(
|
||||
monkeypatch, make_celery_session, make_spec, make_transcript
|
||||
):
|
||||
podcast = _drafting_podcast(make_spec)
|
||||
session = make_celery_session(podcast)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _settlement_fails(**_kwargs):
|
||||
yield SimpleNamespace()
|
||||
raise BillingSettlementError(
|
||||
usage_type="podcast_generation",
|
||||
user_id=uuid4(),
|
||||
cause=RuntimeError("finalize failed"),
|
||||
)
|
||||
|
||||
_wire_boundaries(
|
||||
monkeypatch,
|
||||
session=session,
|
||||
billable_call=_settlement_fails,
|
||||
transcript=make_transcript(),
|
||||
)
|
||||
|
||||
result = await draft._draft_transcript(podcast_id=1, search_space_id=42)
|
||||
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert result["reason"] == "billing"
|
||||
|
|
@ -1,163 +0,0 @@
|
|||
"""The podcast lifecycle: the guarantees the rest of the system relies on.
|
||||
|
||||
These tests drive the aggregate through :class:`PodcastService`'s public
|
||||
methods and observe the resulting status and stored brief/transcript — the
|
||||
domain's contract. They say nothing about how the service stores or flushes,
|
||||
so they survive any refactor that preserves the lifecycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import PodcastStatus
|
||||
from app.podcasts.service import (
|
||||
InvalidTransition,
|
||||
PodcastService,
|
||||
PreconditionFailed,
|
||||
SpecConflict,
|
||||
read_spec,
|
||||
read_transcript,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
async def test_a_podcast_progresses_from_creation_to_ready(
|
||||
fake_session, make_spec, make_transcript
|
||||
):
|
||||
"""The full happy path: create → brief → draft → review → render → ready."""
|
||||
service = PodcastService(fake_session)
|
||||
|
||||
podcast = await service.create(title="Episode 1", search_space_id=7)
|
||||
assert podcast.status == PodcastStatus.PENDING
|
||||
|
||||
spec = make_spec()
|
||||
await service.attach_brief(podcast, spec)
|
||||
assert podcast.status == PodcastStatus.AWAITING_BRIEF
|
||||
assert read_spec(podcast) == spec
|
||||
|
||||
await service.begin_drafting(podcast)
|
||||
assert podcast.status == PodcastStatus.DRAFTING
|
||||
|
||||
transcript = make_transcript()
|
||||
await service.attach_transcript(podcast, transcript)
|
||||
assert podcast.status == PodcastStatus.AWAITING_REVIEW
|
||||
assert read_transcript(podcast) == transcript
|
||||
|
||||
await service.approve(podcast)
|
||||
assert podcast.status == PodcastStatus.RENDERING
|
||||
|
||||
await service.attach_audio(
|
||||
podcast, storage_backend="local", storage_key="k", duration_seconds=42
|
||||
)
|
||||
assert podcast.status == PodcastStatus.READY
|
||||
assert podcast.duration_seconds == 42
|
||||
|
||||
|
||||
async def test_drafting_requires_an_approved_brief(fake_session):
|
||||
"""A brief must exist before drafting can begin."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="No brief", search_space_id=1)
|
||||
|
||||
with pytest.raises(PreconditionFailed):
|
||||
await service.begin_drafting(podcast)
|
||||
|
||||
|
||||
async def test_rendering_requires_a_transcript(fake_session, make_spec):
|
||||
"""Approval to render is refused when no transcript has been drafted."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="No transcript", search_space_id=1)
|
||||
await service.attach_brief(podcast, make_spec())
|
||||
await service.begin_drafting(podcast)
|
||||
|
||||
with pytest.raises(PreconditionFailed):
|
||||
await service.approve(podcast)
|
||||
|
||||
|
||||
async def test_regenerate_returns_a_reviewed_transcript_to_drafting(
|
||||
fake_session, make_spec, make_transcript
|
||||
):
|
||||
"""At the go/no-go gate, rejecting sends the podcast back to drafting."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="Redo", search_space_id=1)
|
||||
await service.attach_brief(podcast, make_spec())
|
||||
await service.begin_drafting(podcast)
|
||||
await service.attach_transcript(podcast, make_transcript())
|
||||
|
||||
await service.regenerate(podcast)
|
||||
|
||||
assert podcast.status == PodcastStatus.DRAFTING
|
||||
|
||||
|
||||
async def test_brief_can_be_edited_at_the_gate_and_bumps_its_version(
|
||||
fake_session, make_spec
|
||||
):
|
||||
"""Editing the brief while awaiting review records it and advances version."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="Editable", search_space_id=1)
|
||||
await service.attach_brief(podcast, make_spec(language="en"))
|
||||
starting_version = podcast.spec_version
|
||||
|
||||
await service.update_spec(podcast, make_spec(language="fr"), starting_version)
|
||||
|
||||
assert read_spec(podcast).language == "fr"
|
||||
assert podcast.spec_version == starting_version + 1
|
||||
|
||||
|
||||
async def test_editing_a_brief_with_a_stale_version_conflicts(
|
||||
fake_session, make_spec
|
||||
):
|
||||
"""A concurrent edit racing on a stale version is rejected, not silently lost."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="Raced", search_space_id=1)
|
||||
await service.attach_brief(podcast, make_spec())
|
||||
current = podcast.spec_version
|
||||
|
||||
with pytest.raises(SpecConflict):
|
||||
await service.update_spec(podcast, make_spec(language="es"), current - 1)
|
||||
|
||||
|
||||
async def test_brief_cannot_be_edited_after_the_gate_closes(
|
||||
fake_session, make_spec
|
||||
):
|
||||
"""Once drafting starts, the brief is settled and edits are refused."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="Locked", search_space_id=1)
|
||||
await service.attach_brief(podcast, make_spec())
|
||||
await service.begin_drafting(podcast)
|
||||
|
||||
with pytest.raises(InvalidTransition):
|
||||
await service.update_spec(podcast, make_spec(language="es"), podcast.spec_version)
|
||||
|
||||
|
||||
async def test_a_podcast_can_be_cancelled_while_in_flight(fake_session, make_spec):
|
||||
"""Cancellation is available from a non-terminal state."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="Abort", search_space_id=1)
|
||||
await service.attach_brief(podcast, make_spec())
|
||||
|
||||
await service.cancel(podcast)
|
||||
|
||||
assert podcast.status == PodcastStatus.CANCELLED
|
||||
|
||||
|
||||
async def test_failure_records_a_reason(fake_session):
|
||||
"""Failing a podcast captures a human-readable reason."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="Boom", search_space_id=1)
|
||||
|
||||
await service.fail(podcast, "tts provider unavailable")
|
||||
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert podcast.error == "tts provider unavailable"
|
||||
|
||||
|
||||
async def test_terminal_podcasts_reject_further_transitions(fake_session):
|
||||
"""A finished podcast cannot be cancelled or otherwise moved."""
|
||||
service = PodcastService(fake_session)
|
||||
podcast = await service.create(title="Done", search_space_id=1)
|
||||
await service.cancel(podcast)
|
||||
|
||||
with pytest.raises(InvalidTransition):
|
||||
await service.fail(podcast, "too late")
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
"""Failure recording shared by the podcast tasks.
|
||||
|
||||
When a task body raises, ``mark_failed`` is the safety net that records the
|
||||
reason on the row. Its contract has two halves worth securing: a still-running
|
||||
podcast is moved to FAILED with the reason, and a podcast that already reached a
|
||||
terminal state is left exactly as it was rather than forced. Only the database
|
||||
(a real boundary) is doubled; the lifecycle service runs for real.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.podcasts.persistence import Podcast, PodcastStatus
|
||||
from app.podcasts.tasks import runtime
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _podcast(status: PodcastStatus) -> Podcast:
|
||||
podcast = Podcast(title="Episode", search_space_id=1, status=status, spec_version=1)
|
||||
podcast.id = 1
|
||||
return podcast
|
||||
|
||||
|
||||
async def test_marking_failed_records_the_reason_on_a_running_podcast(
|
||||
monkeypatch, session_maker_for, make_celery_session
|
||||
):
|
||||
podcast = _podcast(PodcastStatus.DRAFTING)
|
||||
session = make_celery_session(podcast)
|
||||
monkeypatch.setattr(runtime, "get_celery_session_maker", session_maker_for(session))
|
||||
|
||||
await runtime.mark_failed(1, "tts provider unavailable")
|
||||
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
assert podcast.error == "tts provider unavailable"
|
||||
|
||||
|
||||
async def test_marking_failed_leaves_an_already_terminal_podcast_untouched(
|
||||
monkeypatch, session_maker_for, make_celery_session
|
||||
):
|
||||
podcast = _podcast(PodcastStatus.CANCELLED)
|
||||
session = make_celery_session(podcast)
|
||||
monkeypatch.setattr(runtime, "get_celery_session_maker", session_maker_for(session))
|
||||
|
||||
await runtime.mark_failed(1, "too late")
|
||||
|
||||
assert podcast.status == PodcastStatus.CANCELLED
|
||||
|
||||
|
||||
async def test_marking_a_missing_podcast_failed_is_a_no_op(
|
||||
monkeypatch, session_maker_for, make_celery_session
|
||||
):
|
||||
session = make_celery_session(None)
|
||||
monkeypatch.setattr(runtime, "get_celery_session_maker", session_maker_for(session))
|
||||
|
||||
await runtime.mark_failed(999, "gone") # must not raise
|
||||
Loading…
Add table
Add a link
Reference in a new issue