diff --git a/surfsense_backend/tests/README.md b/surfsense_backend/tests/README.md index 5764252a5..23b161f99 100644 --- a/surfsense_backend/tests/README.md +++ b/surfsense_backend/tests/README.md @@ -42,7 +42,7 @@ Maximize logic covered by unit tests; keep integration tests for what genuinely `conftest.py` is scoped to its directory and below. Keep truly global fixtures in `tests/conftest.py`; put module-specific fixtures in that module's `conftest.py` so a DB fixture never loads for a pure unit test. -For API integration tests, override `get_async_session` and `current_active_user` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically. +For API integration tests, override `get_async_session` and `get_auth_context` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically. ## Import mode diff --git a/surfsense_backend/tests/integration/chat/test_append_message_recovery.py b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py index a5182a978..c6a40c356 100644 --- a/surfsense_backend/tests/integration/chat/test_append_message_recovery.py +++ b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py @@ -40,6 +40,7 @@ import pytest_asyncio from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( ChatVisibility, NewChatMessage, @@ -395,7 +396,7 @@ class TestAppendMessageRecoveryAfterFinalize: thread_id=thread_id, request=request, session=db_session, - user=db_user, + auth=AuthContext.session(db_user), ) # Response must echo the SERVER's rich payload, not the FE's @@ -469,7 +470,7 @@ class TestAppendMessageRecoveryAfterFinalize: thread_id=thread_id, request=_FakeRequest(fe_request_body), session=db_session, - user=db_user, + auth=AuthContext.session(db_user), ) assert fe_response.role == NewChatMessageRole.ASSISTANT @@ -552,7 +553,7 @@ class TestAppendMessageRecoveryAfterFinalize: } ), session=db_session, - user=db_user, + auth=AuthContext.session(db_user), ) assert ok_response.role == NewChatMessageRole.USER assert ok_response.turn_id is None diff --git a/surfsense_backend/tests/integration/chat/test_thread_visibility.py b/surfsense_backend/tests/integration/chat/test_thread_visibility.py index 464d389db..ba6f2a66f 100644 --- a/surfsense_backend/tests/integration/chat/test_thread_visibility.py +++ b/surfsense_backend/tests/integration/chat/test_thread_visibility.py @@ -16,6 +16,7 @@ from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( ChatVisibility, SearchSpace, @@ -33,6 +34,10 @@ from app.schemas.new_chat import ( pytestmark = pytest.mark.integration +def _auth(user: User) -> AuthContext: + return AuthContext.session(user) + + @pytest_asyncio.fixture async def db_member(db_session: AsyncSession, db_search_space: SearchSpace) -> User: member = User( @@ -85,7 +90,7 @@ async def _create_thread( visibility=ChatVisibility.PRIVATE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) @@ -108,13 +113,13 @@ async def test_private_thread_is_hidden_from_other_search_space_member( member_threads = await new_chat_routes.list_threads( search_space_id=db_search_space.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) member_search = await new_chat_routes.search_threads( search_space_id=db_search_space.id, title="Visibility", session=db_session, - user=db_member, + auth=_auth(db_member), ) assert thread.id not in _active_thread_ids(member_threads) @@ -123,7 +128,7 @@ async def test_private_thread_is_hidden_from_other_search_space_member( await new_chat_routes.get_thread_full( thread_id=thread.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) assert exc_info.value.status_code == 403 @@ -142,24 +147,24 @@ async def test_creator_can_share_thread_and_member_can_list_search_read_it( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) member_threads = await new_chat_routes.list_threads( search_space_id=db_search_space.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) member_search = await new_chat_routes.search_threads( search_space_id=db_search_space.id, title="Visibility", session=db_session, - user=db_member, + auth=_auth(db_member), ) full_thread = await new_chat_routes.get_thread_full( thread_id=thread.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) assert updated.visibility == ChatVisibility.SEARCH_SPACE @@ -181,20 +186,20 @@ async def test_rename_and_archive_do_not_reset_shared_visibility( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) renamed = await new_chat_routes.update_thread( thread_id=thread.id, thread_update=NewChatThreadUpdate(title="Renamed Shared Chat"), session=db_session, - user=db_user, + auth=_auth(db_user), ) archived = await new_chat_routes.update_thread( thread_id=thread.id, thread_update=NewChatThreadUpdate(archived=True), session=db_session, - user=db_user, + auth=_auth(db_user), ) assert renamed.visibility == ChatVisibility.SEARCH_SPACE @@ -215,7 +220,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) with pytest.raises(HTTPException) as exc_info: @@ -225,7 +230,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private( visibility=ChatVisibility.PRIVATE, ), session=db_session, - user=db_member, + auth=_auth(db_member), ) assert exc_info.value.status_code == 403 @@ -244,7 +249,7 @@ async def test_creator_can_make_shared_thread_private_again( visibility=ChatVisibility.SEARCH_SPACE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) private_again = await new_chat_routes.update_thread_visibility( @@ -253,18 +258,18 @@ async def test_creator_can_make_shared_thread_private_again( visibility=ChatVisibility.PRIVATE, ), session=db_session, - user=db_user, + auth=_auth(db_user), ) member_threads = await new_chat_routes.list_threads( search_space_id=db_search_space.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) member_search = await new_chat_routes.search_threads( search_space_id=db_search_space.id, title="Visibility", session=db_session, - user=db_member, + auth=_auth(db_member), ) assert private_again.visibility == ChatVisibility.PRIVATE @@ -274,6 +279,6 @@ async def test_creator_can_make_shared_thread_private_again( await new_chat_routes.get_thread_full( thread_id=thread.id, session=db_session, - user=db_member, + auth=_auth(db_member), ) assert exc_info.value.status_code == 403 diff --git a/surfsense_backend/tests/integration/composio/conftest.py b/surfsense_backend/tests/integration/composio/conftest.py index 44d707ec3..578b5b228 100644 --- a/surfsense_backend/tests/integration/composio/conftest.py +++ b/surfsense_backend/tests/integration/composio/conftest.py @@ -15,6 +15,7 @@ from httpx import ASGITransport from sqlalchemy.ext.asyncio import AsyncSession from app.app import app, limiter +from app.auth.context import AuthContext from app.config import config from app.db import ( SearchSourceConnector, @@ -22,7 +23,7 @@ from app.db import ( User, get_async_session, ) -from app.users import current_active_user +from app.users import get_auth_context pytestmark = pytest.mark.integration @@ -40,12 +41,12 @@ async def client( async def override_session() -> AsyncGenerator[AsyncSession, None]: yield db_session - async def override_user() -> User: - return db_user + async def override_auth() -> AuthContext: + return AuthContext.session(db_user) previous_overrides = app.dependency_overrides.copy() app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user + app.dependency_overrides[get_auth_context] = override_auth try: async with httpx.AsyncClient( diff --git a/surfsense_backend/tests/integration/notifications/conftest.py b/surfsense_backend/tests/integration/notifications/conftest.py index 17a44a51d..e410d0d55 100644 --- a/surfsense_backend/tests/integration/notifications/conftest.py +++ b/surfsense_backend/tests/integration/notifications/conftest.py @@ -1,9 +1,9 @@ """Notifications integration fixtures. -The app's DB session and current-user dependencies are overridden to ride the +The app's DB session and auth-context dependencies are overridden to ride the test's transactional `db_session`, so API calls and seeded rows share one -transaction that rolls back per test. Overriding `current_active_user` also -bypasses real JWT auth, so these tests don't depend on AUTH_TYPE. +transaction that rolls back per test. Overriding `get_auth_context` also bypasses +real JWT auth, so these tests don't depend on AUTH_TYPE. """ from __future__ import annotations @@ -17,8 +17,9 @@ from httpx import ASGITransport from sqlalchemy.ext.asyncio import AsyncSession from app.app import app, limiter +from app.auth.context import AuthContext from app.db import User, get_async_session -from app.users import current_active_user +from app.users import get_auth_context pytestmark = pytest.mark.integration @@ -33,12 +34,12 @@ async def client( async def override_session() -> AsyncGenerator[AsyncSession, None]: yield db_session - async def override_user() -> User: - return db_user + async def override_auth() -> AuthContext: + return AuthContext.session(db_user) previous_overrides = app.dependency_overrides.copy() app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user + app.dependency_overrides[get_auth_context] = override_auth try: async with httpx.AsyncClient( diff --git a/surfsense_backend/tests/integration/podcasts/conftest.py b/surfsense_backend/tests/integration/podcasts/conftest.py index 75248a6a1..067924ad5 100644 --- a/surfsense_backend/tests/integration/podcasts/conftest.py +++ b/surfsense_backend/tests/integration/podcasts/conftest.py @@ -24,6 +24,7 @@ from httpx import ASGITransport from sqlalchemy.ext.asyncio import AsyncSession from app.app import app, limiter +from app.auth.context import AuthContext from app.config import config as app_config from app.db import SearchSpace, User, get_async_session from app.podcasts.persistence import Podcast, PodcastStatus @@ -39,7 +40,7 @@ from app.podcasts.schemas import ( from app.podcasts.service import PodcastService from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech from app.routes.search_spaces_routes import create_default_roles_and_membership -from app.users import current_active_user +from app.users import get_auth_context pytestmark = pytest.mark.integration @@ -54,12 +55,12 @@ async def client( async def override_session() -> AsyncGenerator[AsyncSession, None]: yield db_session - async def override_user() -> User: - return db_user + async def override_auth() -> AuthContext: + return AuthContext.session(db_user) previous_overrides = app.dependency_overrides.copy() app.dependency_overrides[get_async_session] = override_session - app.dependency_overrides[current_active_user] = override_user + app.dependency_overrides[get_auth_context] = override_auth try: async with httpx.AsyncClient( @@ -290,7 +291,7 @@ def act_as(): """ def _act(user: User) -> None: - app.dependency_overrides[current_active_user] = lambda: user + app.dependency_overrides[get_auth_context] = lambda: AuthContext.session(user) return _act diff --git a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py index 22f6c6de5..d56c18420 100644 --- a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py +++ b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py @@ -28,6 +28,7 @@ from sqlalchemy import func, select, text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ( SearchSourceConnector, SearchSourceConnectorType, @@ -42,6 +43,7 @@ from app.routes.obsidian_plugin_routes import ( obsidian_stats, obsidian_sync, ) +from app.routes.search_spaces_routes import create_default_roles_and_membership from app.schemas.obsidian_plugin import ( ConnectRequest, DeleteAck, @@ -65,6 +67,10 @@ pytestmark = pytest.mark.integration # --------------------------------------------------------------------------- +def _auth(user: User) -> AuthContext: + return AuthContext.session(user) + + def _make_note_payload(vault_id: str, path: str, content_hash: str) -> NotePayload: """Minimal NotePayload that the schema accepts; the indexer is mocked out so the values don't have to round-trip through the real pipeline.""" @@ -102,6 +108,8 @@ async def race_user_and_space(async_engine): ) space = SearchSpace(name="Race Space", user_id=user_id) setup.add_all([user, space]) + await setup.flush() + await create_default_roles_and_membership(setup, space.id, user_id) await setup.commit() await setup.refresh(space) space_id = space.id @@ -116,6 +124,14 @@ async def race_user_and_space(async_engine): text("DELETE FROM search_source_connectors WHERE user_id = :uid"), {"uid": user_id}, ) + await cleanup.execute( + text("DELETE FROM search_space_memberships WHERE search_space_id = :id"), + {"id": space_id}, + ) + await cleanup.execute( + text("DELETE FROM search_space_roles WHERE search_space_id = :id"), + {"id": space_id}, + ) await cleanup.execute( text("DELETE FROM searchspaces WHERE id = :id"), {"id": space_id}, @@ -154,7 +170,7 @@ class TestConnectRace: search_space_id=space_id, vault_fingerprint=fingerprint, ) - await obsidian_connect(payload, user=fresh_user, session=s) + await obsidian_connect(payload, auth=_auth(fresh_user), session=s) results = await asyncio.gather(_call("a"), _call("b"), return_exceptions=True) for r in results: @@ -281,7 +297,7 @@ class TestConnectRace: search_space_id=space_id, vault_fingerprint=fingerprint, ), - user=fresh_user, + auth=_auth(fresh_user), session=s, ) @@ -294,7 +310,7 @@ class TestConnectRace: search_space_id=space_id, vault_fingerprint=fingerprint, ), - user=fresh_user, + auth=_auth(fresh_user), session=s, ) @@ -337,7 +353,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert connect_resp.connector_id > 0 @@ -361,7 +377,7 @@ class TestWireContractSmoke: _make_note_payload(vault_id, "fail.md", "hash-fail"), ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -394,7 +410,7 @@ class TestWireContractSmoke: _make_note_payload(vault_id, "fail.md", "h2"), ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert sync_resp.indexed == 1 @@ -420,7 +436,7 @@ class TestWireContractSmoke: RenameItem(old_path="missing.md", new_path="x.md"), ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert isinstance(rename_resp, RenameAck) @@ -441,7 +457,7 @@ class TestWireContractSmoke: ): delete_resp = await obsidian_delete_notes( DeleteBatchRequest(vault_id=vault_id, paths=["b.md", "ghost.md"]), - user=db_user, + auth=_auth(db_user), session=db_session, ) assert isinstance(delete_resp, DeleteAck) @@ -456,7 +472,7 @@ class TestWireContractSmoke: # upsert_note was mocked) but the response shape is what we care # about. manifest_resp = await obsidian_manifest( - vault_id=vault_id, user=db_user, session=db_session + vault_id=vault_id, auth=_auth(db_user), session=db_session ) assert isinstance(manifest_resp, ManifestResponse) assert manifest_resp.vault_id == vault_id @@ -464,7 +480,7 @@ class TestWireContractSmoke: # 6. /stats — same; row count is 0 because upsert_note was mocked. stats_resp = await obsidian_stats( - vault_id=vault_id, user=db_user, session=db_session + vault_id=vault_id, auth=_auth(db_user), session=db_session ) assert isinstance(stats_resp, StatsResponse) assert stats_resp.vault_id == vault_id @@ -482,7 +498,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -511,7 +527,7 @@ class TestWireContractSmoke: binary_note, ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -533,7 +549,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -562,7 +578,7 @@ class TestWireContractSmoke: bad_note, ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -587,7 +603,7 @@ class TestWireContractSmoke: search_space_id=db_search_space.id, vault_fingerprint="fp-" + uuid.uuid4().hex, ), - user=db_user, + auth=_auth(db_user), session=db_session, ) @@ -616,7 +632,7 @@ class TestWireContractSmoke: mismatched, ], ), - user=db_user, + auth=_auth(db_user), session=db_session, ) diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py index c97dec6a2..477d927e2 100644 --- a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -15,6 +15,7 @@ import pytest from fastapi import HTTPException import app.automations.services.automation as automation_mod +from app.auth.context import AuthContext from app.automations.schemas.api import AutomationCreate, AutomationUpdate from app.automations.schemas.definition.envelope import ( AutomationDefinition, @@ -45,7 +46,8 @@ class _FakeSession: def _service(search_space: Any) -> AutomationService: return AutomationService( - session=_FakeSession(search_space), user=SimpleNamespace(id="u-1") + session=_FakeSession(search_space), + auth=AuthContext.session(SimpleNamespace(id="u-1")), ) diff --git a/surfsense_backend/tests/unit/gateway/test_webhook_routes.py b/surfsense_backend/tests/unit/gateway/test_webhook_routes.py index aa8bd3a89..354c3037d 100644 --- a/surfsense_backend/tests/unit/gateway/test_webhook_routes.py +++ b/surfsense_backend/tests/unit/gateway/test_webhook_routes.py @@ -9,6 +9,7 @@ from types import SimpleNamespace import pytest +from app.auth.context import AuthContext from app.db import ExternalChatAccount, ExternalChatAccountMode, ExternalChatPlatform from app.routes import gateway_webhook_routes as routes @@ -333,7 +334,9 @@ async def test_discord_gateway_install_returns_oauth_url(monkeypatch, mocker): response = await routes.install_discord_gateway( search_space_id=123, - user=SimpleNamespace(id="00000000-0000-0000-0000-000000000001"), + auth=AuthContext.session( + SimpleNamespace(id="00000000-0000-0000-0000-000000000001") + ), session=mocker.AsyncMock(), )