refactor(routes): replace user variable with auth context in tests

This commit is contained in:
Anish Sarkar 2026-06-20 03:34:40 +05:30
parent af5a112212
commit 1e8baa10ec
5 changed files with 34 additions and 36 deletions

View file

@ -24,7 +24,6 @@ from app.db import (
Permission, Permission,
SearchSpace, SearchSpace,
SearchSpaceMembership, SearchSpaceMembership,
User,
get_async_session, get_async_session,
) )
from app.podcasts.generation.brief import propose_brief from app.podcasts.generation.brief import propose_brief
@ -71,7 +70,7 @@ async def list_podcasts(
raise HTTPException(status_code=400, detail="Invalid pagination parameters") raise HTTPException(status_code=400, detail="Invalid pagination parameters")
if search_space_id is not None: if search_space_id is not None:
await _require(session, user, search_space_id, Permission.PODCASTS_READ) await _require(session, auth, search_space_id, Permission.PODCASTS_READ)
query = ( query = (
select(Podcast) select(Podcast)
.where(Podcast.search_space_id == search_space_id) .where(Podcast.search_space_id == search_space_id)
@ -136,7 +135,6 @@ async def preview_voice(
voice_id: str, voice_id: str,
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""A short audio sample of a voice, so users pick by sound.""" """A short audio sample of a voice, so users pick by sound."""
if not app_config.TTS_SERVICE: if not app_config.TTS_SERVICE:
raise HTTPException(status_code=503, detail="No TTS provider configured") raise HTTPException(status_code=503, detail="No TTS provider configured")
@ -161,8 +159,7 @@ async def create_podcast(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user await _require(session, auth, body.search_space_id, Permission.PODCASTS_CREATE)
await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE)
service = PodcastService(session) service = PodcastService(session)
podcast = await service.create( podcast = await service.create(
@ -191,8 +188,7 @@ async def get_podcast(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ)
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
return PodcastDetail.of(podcast) return PodcastDetail.of(podcast)
@ -203,8 +199,7 @@ async def update_spec(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
await PodcastService(session).update_spec( await PodcastService(session).update_spec(
podcast, body.spec, body.expected_version podcast, body.spec, body.expected_version
@ -219,9 +214,8 @@ async def approve_brief(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Approve the brief and start drafting the transcript.""" """Approve the brief and start drafting the transcript."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
await PodcastService(session).begin_drafting(podcast) await PodcastService(session).begin_drafting(podcast)
await session.commit() await session.commit()
@ -237,9 +231,8 @@ async def regenerate_transcript(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Reopen the brief gate for a fresh take; drafting waits for re-approval.""" """Reopen the brief gate for a fresh take; drafting waits for re-approval."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
await PodcastService(session).regenerate(podcast) await PodcastService(session).regenerate(podcast)
await session.commit() await session.commit()
@ -252,9 +245,8 @@ async def revert_regeneration(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Back out of a regeneration and return to the finished episode.""" """Back out of a regeneration and return to the finished episode."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
await PodcastService(session).revert_regeneration(podcast) await PodcastService(session).revert_regeneration(podcast)
await session.commit() await session.commit()
@ -267,8 +259,7 @@ async def cancel_podcast(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE)
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
await PodcastService(session).cancel(podcast) await PodcastService(session).cancel(podcast)
await session.commit() await session.commit()
@ -281,8 +272,7 @@ async def delete_podcast(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_DELETE)
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE)
await purge_audio(podcast) await purge_audio(podcast)
await session.delete(podcast) await session.delete(podcast)
await session.commit() await session.commit()
@ -295,8 +285,7 @@ async def stream_podcast(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
auth: AuthContext = Depends(get_auth_context), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ)
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
if podcast.storage_key: if podcast.storage_key:
# Verify first so a missing object is a 404, not a mid-stream crash. # Verify first so a missing object is a 404, not a mid-stream crash.
@ -339,7 +328,6 @@ async def _require(
search_space_id: int, search_space_id: int,
permission: Permission, permission: Permission,
) -> None: ) -> None:
user = auth.user
await check_permission( await check_permission(
session, session,
auth, auth,
@ -355,11 +343,10 @@ async def _load(
podcast_id: int, podcast_id: int,
permission: Permission, permission: Permission,
) -> Podcast: ) -> Podcast:
user = auth.user
podcast = await PodcastRepository(session).get(podcast_id) podcast = await PodcastRepository(session).get(podcast_id)
if podcast is None: if podcast is None:
raise HTTPException(status_code=404, detail="Podcast not found") raise HTTPException(status_code=404, detail="Podcast not found")
await _require(session, user, podcast.search_space_id, permission) await _require(session, auth, podcast.search_space_id, permission)
return podcast return podcast

View file

@ -23,6 +23,7 @@ import pytest
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ( from app.db import (
SearchSourceConnector, SearchSourceConnector,
SearchSourceConnectorType, SearchSourceConnectorType,
@ -109,7 +110,7 @@ class TestConnectorIndexCrossSpaceAuthz:
connector_id=connector_a.id, connector_id=connector_a.id,
search_space_id=space_b.id, # the attacker's own space search_space_id=space_b.id, # the attacker's own space
session=db_session, session=db_session,
user=attacker, auth=AuthContext.session(attacker),
) )
assert exc_info.value.status_code == 404 assert exc_info.value.status_code == 404
@ -140,7 +141,7 @@ class TestConnectorIndexCrossSpaceAuthz:
connector_id=connector.id, connector_id=connector.id,
search_space_id=space.id, # the connector's own space search_space_id=space.id, # the connector's own space
session=db_session, session=db_session,
user=owner, auth=AuthContext.session(owner),
) )
check_permission_mock.assert_awaited_once() check_permission_mock.assert_awaited_once()

View file

@ -38,7 +38,9 @@ async def cleanup_supervisors():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch): async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook") monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook")
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
await byo_long_poll.start_byo_long_poll_supervisors() await byo_long_poll.start_byo_long_poll_supervisors()
@ -47,9 +49,11 @@ async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch): async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr( monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
) )
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
session = mocker.AsyncMock() session = mocker.AsyncMock()
session.execute.return_value = ScalarResult([]) session.execute.return_value = ScalarResult([])
monkeypatch.setattr( monkeypatch.setattr(
@ -67,9 +71,11 @@ async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatc
async def test_start_byo_long_poll_spawns_one_supervisor_per_account( async def test_start_byo_long_poll_spawns_one_supervisor_per_account(
mocker, monkeypatch mocker, monkeypatch
): ):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr( monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
) )
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
accounts = [mocker.Mock(id=1), mocker.Mock(id=2)] accounts = [mocker.Mock(id=1), mocker.Mock(id=2)]
session = mocker.AsyncMock() session = mocker.AsyncMock()
session.execute.return_value = ScalarResult(accounts) session.execute.return_value = ScalarResult(accounts)
@ -115,9 +121,11 @@ async def test_supervisor_retries_after_run_returns(mocker, monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch): async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch):
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True)
monkeypatch.setattr( monkeypatch.setattr(
byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll"
) )
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled")
session = mocker.AsyncMock() session = mocker.AsyncMock()
session.execute.return_value = ScalarResult([mocker.Mock(id=1)]) session.execute.return_value = ScalarResult([mocker.Mock(id=1)])
monkeypatch.setattr( monkeypatch.setattr(

View file

@ -27,6 +27,7 @@ async def test_inbox_worker_claims_and_processes_in_fastapi_process(
async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch): async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch):
started = asyncio.Event() started = asyncio.Event()
stopped = asyncio.Event() stopped = asyncio.Event()
monkeypatch.setattr(inbox_worker.config, "GATEWAY_ENABLED", True)
async def run_forever(): async def run_forever():
started.set() started.set()

View file

@ -19,6 +19,7 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
from app.auth.context import AuthContext
from app.routes import agent_revert_route from app.routes import agent_revert_route
from app.services.revert_service import RevertOutcome from app.services.revert_service import RevertOutcome
@ -147,7 +148,7 @@ class TestFlagGuard:
thread_id=1, thread_id=1,
chat_turn_id="42:1700000000000", chat_turn_id="42:1700000000000",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert getattr(exc.value, "status_code", None) == 503 assert getattr(exc.value, "status_code", None) == 503
@ -167,7 +168,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-empty", chat_turn_id="ct-empty",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert response.status == "ok" assert response.status == "ok"
assert response.total == 0 assert response.total == 0
@ -209,7 +210,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-3", chat_turn_id="ct-3",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert response.status == "ok" assert response.status == "ok"
@ -248,7 +249,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-i", chat_turn_id="ct-i",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert response.status == "ok" assert response.status == "ok"
assert response.already_reverted == 1 assert response.already_reverted == 1
@ -275,7 +276,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-rev", chat_turn_id="ct-rev",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert response.status == "ok" assert response.status == "ok"
assert response.results[0].status == "skipped" assert response.results[0].status == "skipped"
@ -315,7 +316,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-mix", chat_turn_id="ct-mix",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert response.status == "partial" assert response.status == "partial"
assert response.reverted == 1 assert response.reverted == 1
@ -354,7 +355,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-fail", chat_turn_id="ct-fail",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert response.status == "partial" assert response.status == "partial"
assert response.failed == 1 assert response.failed == 1
@ -386,7 +387,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-perm", chat_turn_id="ct-perm",
session=session, session=session,
user=_FakeUser(id="not-owner"), auth=AuthContext.session(_FakeUser(id="not-owner")),
) )
assert response.status == "partial" assert response.status == "partial"
assert response.results[0].status == "permission_denied" assert response.results[0].status == "permission_denied"
@ -449,7 +450,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-mixed-all", chat_turn_id="ct-mixed-all",
session=session, session=session,
user=_FakeUser(), # only id=7 has a different user_id auth=AuthContext.session(_FakeUser()), # only id=7 has a different user_id
) )
assert response.total == len(rows) == 6 assert response.total == len(rows) == 6
@ -518,7 +519,7 @@ class TestRevertTurnDispatch:
thread_id=1, thread_id=1,
chat_turn_id="ct-race", chat_turn_id="ct-race",
session=session, session=session,
user=_FakeUser(), auth=AuthContext.session(_FakeUser()),
) )
assert response.failed == 0, ( assert response.failed == 0, (