From 1e8baa10ec3635fcb08a0c75ef308f857ca929c2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 20 Jun 2026 03:34:40 +0530 Subject: [PATCH] refactor(routes): replace user variable with auth context in tests --- surfsense_backend/app/podcasts/api/routes.py | 35 ++++++------------- .../integration/test_connector_index_authz.py | 5 +-- .../gateway/test_byo_long_poll_lifespan.py | 8 +++++ .../tests/unit/gateway/test_inbox_worker.py | 1 + .../unit/routes/test_revert_turn_route.py | 21 +++++------ 5 files changed, 34 insertions(+), 36 deletions(-) diff --git a/surfsense_backend/app/podcasts/api/routes.py b/surfsense_backend/app/podcasts/api/routes.py index 2f4c8e4d9..582b0531e 100644 --- a/surfsense_backend/app/podcasts/api/routes.py +++ b/surfsense_backend/app/podcasts/api/routes.py @@ -24,7 +24,6 @@ from app.db import ( Permission, SearchSpace, SearchSpaceMembership, - User, get_async_session, ) from app.podcasts.generation.brief import propose_brief @@ -71,7 +70,7 @@ async def list_podcasts( raise HTTPException(status_code=400, detail="Invalid pagination parameters") if search_space_id is not None: - await _require(session, user, search_space_id, Permission.PODCASTS_READ) + await _require(session, auth, search_space_id, Permission.PODCASTS_READ) query = ( select(Podcast) .where(Podcast.search_space_id == search_space_id) @@ -136,7 +135,6 @@ async def preview_voice( voice_id: str, auth: AuthContext = Depends(get_auth_context), ): - user = auth.user """A short audio sample of a voice, so users pick by sound.""" if not app_config.TTS_SERVICE: raise HTTPException(status_code=503, detail="No TTS provider configured") @@ -161,8 +159,7 @@ async def create_podcast( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user - await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE) + await _require(session, auth, body.search_space_id, Permission.PODCASTS_CREATE) service = PodcastService(session) podcast = await service.create( @@ -191,8 +188,7 @@ async def get_podcast( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ) return PodcastDetail.of(podcast) @@ -203,8 +199,7 @@ async def update_spec( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).update_spec( podcast, body.spec, body.expected_version @@ -219,9 +214,8 @@ async def approve_brief( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user """Approve the brief and start drafting the transcript.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).begin_drafting(podcast) await session.commit() @@ -237,9 +231,8 @@ async def regenerate_transcript( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user """Reopen the brief gate for a fresh take; drafting waits for re-approval.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).regenerate(podcast) await session.commit() @@ -252,9 +245,8 @@ async def revert_regeneration( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user """Back out of a regeneration and return to the finished episode.""" - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).revert_regeneration(podcast) await session.commit() @@ -267,8 +259,7 @@ async def cancel_podcast( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_UPDATE) async with _lifecycle_errors(): await PodcastService(session).cancel(podcast) await session.commit() @@ -281,8 +272,7 @@ async def delete_podcast( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_DELETE) await purge_audio(podcast) await session.delete(podcast) await session.commit() @@ -295,8 +285,7 @@ async def stream_podcast( session: AsyncSession = Depends(get_async_session), auth: AuthContext = Depends(get_auth_context), ): - user = auth.user - podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) + podcast = await _load(session, auth, podcast_id, Permission.PODCASTS_READ) if podcast.storage_key: # Verify first so a missing object is a 404, not a mid-stream crash. @@ -339,7 +328,6 @@ async def _require( search_space_id: int, permission: Permission, ) -> None: - user = auth.user await check_permission( session, auth, @@ -355,11 +343,10 @@ async def _load( podcast_id: int, permission: Permission, ) -> Podcast: - user = auth.user podcast = await PodcastRepository(session).get(podcast_id) if podcast is None: raise HTTPException(status_code=404, detail="Podcast not found") - await _require(session, user, podcast.search_space_id, permission) + await _require(session, auth, podcast.search_space_id, permission) return podcast diff --git a/surfsense_backend/tests/integration/test_connector_index_authz.py b/surfsense_backend/tests/integration/test_connector_index_authz.py index cea2407cc..b25df7087 100644 --- a/surfsense_backend/tests/integration/test_connector_index_authz.py +++ b/surfsense_backend/tests/integration/test_connector_index_authz.py @@ -23,6 +23,7 @@ import pytest from fastapi import HTTPException from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( SearchSourceConnector, SearchSourceConnectorType, @@ -109,7 +110,7 @@ class TestConnectorIndexCrossSpaceAuthz: connector_id=connector_a.id, search_space_id=space_b.id, # the attacker's own space session=db_session, - user=attacker, + auth=AuthContext.session(attacker), ) assert exc_info.value.status_code == 404 @@ -140,7 +141,7 @@ class TestConnectorIndexCrossSpaceAuthz: connector_id=connector.id, search_space_id=space.id, # the connector's own space session=db_session, - user=owner, + auth=AuthContext.session(owner), ) check_permission_mock.assert_awaited_once() diff --git a/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py b/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py index de4386abb..c184af601 100644 --- a/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py +++ b/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py @@ -38,7 +38,9 @@ async def cleanup_supervisors(): @pytest.mark.asyncio async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr(byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "webhook") + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") await byo_long_poll.start_byo_long_poll_supervisors() @@ -47,9 +49,11 @@ async def test_start_byo_long_poll_noops_when_mode_is_webhook(monkeypatch): @pytest.mark.asyncio async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr( byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" ) + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") session = mocker.AsyncMock() session.execute.return_value = ScalarResult([]) monkeypatch.setattr( @@ -67,9 +71,11 @@ async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatc async def test_start_byo_long_poll_spawns_one_supervisor_per_account( mocker, monkeypatch ): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr( byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" ) + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") accounts = [mocker.Mock(id=1), mocker.Mock(id=2)] session = mocker.AsyncMock() session.execute.return_value = ScalarResult(accounts) @@ -115,9 +121,11 @@ async def test_supervisor_retries_after_run_returns(mocker, monkeypatch): @pytest.mark.asyncio async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_ENABLED", True) monkeypatch.setattr( byo_long_poll.config, "GATEWAY_TELEGRAM_INTAKE_MODE", "longpoll" ) + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_WHATSAPP_INTAKE_MODE", "disabled") session = mocker.AsyncMock() session.execute.return_value = ScalarResult([mocker.Mock(id=1)]) monkeypatch.setattr( diff --git a/surfsense_backend/tests/unit/gateway/test_inbox_worker.py b/surfsense_backend/tests/unit/gateway/test_inbox_worker.py index 1e5b2a184..0ee661102 100644 --- a/surfsense_backend/tests/unit/gateway/test_inbox_worker.py +++ b/surfsense_backend/tests/unit/gateway/test_inbox_worker.py @@ -27,6 +27,7 @@ async def test_inbox_worker_claims_and_processes_in_fastapi_process( async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch): started = asyncio.Event() stopped = asyncio.Event() + monkeypatch.setattr(inbox_worker.config, "GATEWAY_ENABLED", True) async def run_forever(): started.set() diff --git a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py index 35d409a40..44fcfe042 100644 --- a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py +++ b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py @@ -19,6 +19,7 @@ from unittest.mock import AsyncMock, patch import pytest from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFlags +from app.auth.context import AuthContext from app.routes import agent_revert_route from app.services.revert_service import RevertOutcome @@ -147,7 +148,7 @@ class TestFlagGuard: thread_id=1, chat_turn_id="42:1700000000000", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert getattr(exc.value, "status_code", None) == 503 @@ -167,7 +168,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-empty", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" assert response.total == 0 @@ -209,7 +210,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-3", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" @@ -248,7 +249,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-i", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" assert response.already_reverted == 1 @@ -275,7 +276,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-rev", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "ok" assert response.results[0].status == "skipped" @@ -315,7 +316,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-mix", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "partial" assert response.reverted == 1 @@ -354,7 +355,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-fail", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.status == "partial" assert response.failed == 1 @@ -386,7 +387,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-perm", session=session, - user=_FakeUser(id="not-owner"), + auth=AuthContext.session(_FakeUser(id="not-owner")), ) assert response.status == "partial" assert response.results[0].status == "permission_denied" @@ -449,7 +450,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-mixed-all", session=session, - user=_FakeUser(), # only id=7 has a different user_id + auth=AuthContext.session(_FakeUser()), # only id=7 has a different user_id ) assert response.total == len(rows) == 6 @@ -518,7 +519,7 @@ class TestRevertTurnDispatch: thread_id=1, chat_turn_id="ct-race", session=session, - user=_FakeUser(), + auth=AuthContext.session(_FakeUser()), ) assert response.failed == 0, (