refactor(auth): replace user variable with auth context in integration and unit tests

This commit is contained in:
Anish Sarkar 2026-06-20 03:11:00 +05:30
parent 14cb0a22e9
commit af5a112212
9 changed files with 86 additions and 56 deletions

View file

@ -42,7 +42,7 @@ Maximize logic covered by unit tests; keep integration tests for what genuinely
`conftest.py` is scoped to its directory and below. Keep truly global fixtures in `tests/conftest.py`; put module-specific fixtures in that module's `conftest.py` so a DB fixture never loads for a pure unit test.
For API integration tests, override `get_async_session` and `current_active_user` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically.
For API integration tests, override `get_async_session` and `get_auth_context` to ride the test's transactional `db_session` (see `tests/integration/notifications/conftest.py`): rows seeded in the test and rows read via the endpoint share one transaction that rolls back automatically.
## Import mode

View file

@ -40,6 +40,7 @@ import pytest_asyncio
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
ChatVisibility,
NewChatMessage,
@ -395,7 +396,7 @@ class TestAppendMessageRecoveryAfterFinalize:
thread_id=thread_id,
request=request,
session=db_session,
user=db_user,
auth=AuthContext.session(db_user),
)
# Response must echo the SERVER's rich payload, not the FE's
@ -469,7 +470,7 @@ class TestAppendMessageRecoveryAfterFinalize:
thread_id=thread_id,
request=_FakeRequest(fe_request_body),
session=db_session,
user=db_user,
auth=AuthContext.session(db_user),
)
assert fe_response.role == NewChatMessageRole.ASSISTANT
@ -552,7 +553,7 @@ class TestAppendMessageRecoveryAfterFinalize:
}
),
session=db_session,
user=db_user,
auth=AuthContext.session(db_user),
)
assert ok_response.role == NewChatMessageRole.USER
assert ok_response.turn_id is None

View file

@ -16,6 +16,7 @@ from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
ChatVisibility,
SearchSpace,
@ -33,6 +34,10 @@ from app.schemas.new_chat import (
pytestmark = pytest.mark.integration
def _auth(user: User) -> AuthContext:
return AuthContext.session(user)
@pytest_asyncio.fixture
async def db_member(db_session: AsyncSession, db_search_space: SearchSpace) -> User:
member = User(
@ -85,7 +90,7 @@ async def _create_thread(
visibility=ChatVisibility.PRIVATE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
@ -108,13 +113,13 @@ async def test_private_thread_is_hidden_from_other_search_space_member(
member_threads = await new_chat_routes.list_threads(
search_space_id=db_search_space.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
member_search = await new_chat_routes.search_threads(
search_space_id=db_search_space.id,
title="Visibility",
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert thread.id not in _active_thread_ids(member_threads)
@ -123,7 +128,7 @@ async def test_private_thread_is_hidden_from_other_search_space_member(
await new_chat_routes.get_thread_full(
thread_id=thread.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert exc_info.value.status_code == 403
@ -142,24 +147,24 @@ async def test_creator_can_share_thread_and_member_can_list_search_read_it(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
member_threads = await new_chat_routes.list_threads(
search_space_id=db_search_space.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
member_search = await new_chat_routes.search_threads(
search_space_id=db_search_space.id,
title="Visibility",
session=db_session,
user=db_member,
auth=_auth(db_member),
)
full_thread = await new_chat_routes.get_thread_full(
thread_id=thread.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert updated.visibility == ChatVisibility.SEARCH_SPACE
@ -181,20 +186,20 @@ async def test_rename_and_archive_do_not_reset_shared_visibility(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
renamed = await new_chat_routes.update_thread(
thread_id=thread.id,
thread_update=NewChatThreadUpdate(title="Renamed Shared Chat"),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
archived = await new_chat_routes.update_thread(
thread_id=thread.id,
thread_update=NewChatThreadUpdate(archived=True),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
assert renamed.visibility == ChatVisibility.SEARCH_SPACE
@ -215,7 +220,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
with pytest.raises(HTTPException) as exc_info:
@ -225,7 +230,7 @@ async def test_non_creator_cannot_change_shared_thread_back_to_private(
visibility=ChatVisibility.PRIVATE,
),
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert exc_info.value.status_code == 403
@ -244,7 +249,7 @@ async def test_creator_can_make_shared_thread_private_again(
visibility=ChatVisibility.SEARCH_SPACE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
private_again = await new_chat_routes.update_thread_visibility(
@ -253,18 +258,18 @@ async def test_creator_can_make_shared_thread_private_again(
visibility=ChatVisibility.PRIVATE,
),
session=db_session,
user=db_user,
auth=_auth(db_user),
)
member_threads = await new_chat_routes.list_threads(
search_space_id=db_search_space.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
member_search = await new_chat_routes.search_threads(
search_space_id=db_search_space.id,
title="Visibility",
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert private_again.visibility == ChatVisibility.PRIVATE
@ -274,6 +279,6 @@ async def test_creator_can_make_shared_thread_private_again(
await new_chat_routes.get_thread_full(
thread_id=thread.id,
session=db_session,
user=db_member,
auth=_auth(db_member),
)
assert exc_info.value.status_code == 403

View file

@ -15,6 +15,7 @@ from httpx import ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession
from app.app import app, limiter
from app.auth.context import AuthContext
from app.config import config
from app.db import (
SearchSourceConnector,
@ -22,7 +23,7 @@ from app.db import (
User,
get_async_session,
)
from app.users import current_active_user
from app.users import get_auth_context
pytestmark = pytest.mark.integration
@ -40,12 +41,12 @@ async def client(
async def override_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
async def override_user() -> User:
return db_user
async def override_auth() -> AuthContext:
return AuthContext.session(db_user)
previous_overrides = app.dependency_overrides.copy()
app.dependency_overrides[get_async_session] = override_session
app.dependency_overrides[current_active_user] = override_user
app.dependency_overrides[get_auth_context] = override_auth
try:
async with httpx.AsyncClient(

View file

@ -1,9 +1,9 @@
"""Notifications integration fixtures.
The app's DB session and current-user dependencies are overridden to ride the
The app's DB session and auth-context dependencies are overridden to ride the
test's transactional `db_session`, so API calls and seeded rows share one
transaction that rolls back per test. Overriding `current_active_user` also
bypasses real JWT auth, so these tests don't depend on AUTH_TYPE.
transaction that rolls back per test. Overriding `get_auth_context` also bypasses
real JWT auth, so these tests don't depend on AUTH_TYPE.
"""
from __future__ import annotations
@ -17,8 +17,9 @@ from httpx import ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession
from app.app import app, limiter
from app.auth.context import AuthContext
from app.db import User, get_async_session
from app.users import current_active_user
from app.users import get_auth_context
pytestmark = pytest.mark.integration
@ -33,12 +34,12 @@ async def client(
async def override_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
async def override_user() -> User:
return db_user
async def override_auth() -> AuthContext:
return AuthContext.session(db_user)
previous_overrides = app.dependency_overrides.copy()
app.dependency_overrides[get_async_session] = override_session
app.dependency_overrides[current_active_user] = override_user
app.dependency_overrides[get_auth_context] = override_auth
try:
async with httpx.AsyncClient(

View file

@ -24,6 +24,7 @@ from httpx import ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession
from app.app import app, limiter
from app.auth.context import AuthContext
from app.config import config as app_config
from app.db import SearchSpace, User, get_async_session
from app.podcasts.persistence import Podcast, PodcastStatus
@ -39,7 +40,7 @@ from app.podcasts.schemas import (
from app.podcasts.service import PodcastService
from app.podcasts.tts import SynthesisRequest, SynthesizedAudio, TextToSpeech
from app.routes.search_spaces_routes import create_default_roles_and_membership
from app.users import current_active_user
from app.users import get_auth_context
pytestmark = pytest.mark.integration
@ -54,12 +55,12 @@ async def client(
async def override_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
async def override_user() -> User:
return db_user
async def override_auth() -> AuthContext:
return AuthContext.session(db_user)
previous_overrides = app.dependency_overrides.copy()
app.dependency_overrides[get_async_session] = override_session
app.dependency_overrides[current_active_user] = override_user
app.dependency_overrides[get_auth_context] = override_auth
try:
async with httpx.AsyncClient(
@ -290,7 +291,7 @@ def act_as():
"""
def _act(user: User) -> None:
app.dependency_overrides[current_active_user] = lambda: user
app.dependency_overrides[get_auth_context] = lambda: AuthContext.session(user)
return _act

View file

@ -28,6 +28,7 @@ from sqlalchemy import func, select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
@ -42,6 +43,7 @@ from app.routes.obsidian_plugin_routes import (
obsidian_stats,
obsidian_sync,
)
from app.routes.search_spaces_routes import create_default_roles_and_membership
from app.schemas.obsidian_plugin import (
ConnectRequest,
DeleteAck,
@ -65,6 +67,10 @@ pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
def _auth(user: User) -> AuthContext:
return AuthContext.session(user)
def _make_note_payload(vault_id: str, path: str, content_hash: str) -> NotePayload:
"""Minimal NotePayload that the schema accepts; the indexer is mocked
out so the values don't have to round-trip through the real pipeline."""
@ -102,6 +108,8 @@ async def race_user_and_space(async_engine):
)
space = SearchSpace(name="Race Space", user_id=user_id)
setup.add_all([user, space])
await setup.flush()
await create_default_roles_and_membership(setup, space.id, user_id)
await setup.commit()
await setup.refresh(space)
space_id = space.id
@ -116,6 +124,14 @@ async def race_user_and_space(async_engine):
text("DELETE FROM search_source_connectors WHERE user_id = :uid"),
{"uid": user_id},
)
await cleanup.execute(
text("DELETE FROM search_space_memberships WHERE search_space_id = :id"),
{"id": space_id},
)
await cleanup.execute(
text("DELETE FROM search_space_roles WHERE search_space_id = :id"),
{"id": space_id},
)
await cleanup.execute(
text("DELETE FROM searchspaces WHERE id = :id"),
{"id": space_id},
@ -154,7 +170,7 @@ class TestConnectRace:
search_space_id=space_id,
vault_fingerprint=fingerprint,
)
await obsidian_connect(payload, user=fresh_user, session=s)
await obsidian_connect(payload, auth=_auth(fresh_user), session=s)
results = await asyncio.gather(_call("a"), _call("b"), return_exceptions=True)
for r in results:
@ -281,7 +297,7 @@ class TestConnectRace:
search_space_id=space_id,
vault_fingerprint=fingerprint,
),
user=fresh_user,
auth=_auth(fresh_user),
session=s,
)
@ -294,7 +310,7 @@ class TestConnectRace:
search_space_id=space_id,
vault_fingerprint=fingerprint,
),
user=fresh_user,
auth=_auth(fresh_user),
session=s,
)
@ -337,7 +353,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert connect_resp.connector_id > 0
@ -361,7 +377,7 @@ class TestWireContractSmoke:
_make_note_payload(vault_id, "fail.md", "hash-fail"),
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -394,7 +410,7 @@ class TestWireContractSmoke:
_make_note_payload(vault_id, "fail.md", "h2"),
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert sync_resp.indexed == 1
@ -420,7 +436,7 @@ class TestWireContractSmoke:
RenameItem(old_path="missing.md", new_path="x.md"),
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert isinstance(rename_resp, RenameAck)
@ -441,7 +457,7 @@ class TestWireContractSmoke:
):
delete_resp = await obsidian_delete_notes(
DeleteBatchRequest(vault_id=vault_id, paths=["b.md", "ghost.md"]),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
assert isinstance(delete_resp, DeleteAck)
@ -456,7 +472,7 @@ class TestWireContractSmoke:
# upsert_note was mocked) but the response shape is what we care
# about.
manifest_resp = await obsidian_manifest(
vault_id=vault_id, user=db_user, session=db_session
vault_id=vault_id, auth=_auth(db_user), session=db_session
)
assert isinstance(manifest_resp, ManifestResponse)
assert manifest_resp.vault_id == vault_id
@ -464,7 +480,7 @@ class TestWireContractSmoke:
# 6. /stats — same; row count is 0 because upsert_note was mocked.
stats_resp = await obsidian_stats(
vault_id=vault_id, user=db_user, session=db_session
vault_id=vault_id, auth=_auth(db_user), session=db_session
)
assert isinstance(stats_resp, StatsResponse)
assert stats_resp.vault_id == vault_id
@ -482,7 +498,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -511,7 +527,7 @@ class TestWireContractSmoke:
binary_note,
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -533,7 +549,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -562,7 +578,7 @@ class TestWireContractSmoke:
bad_note,
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -587,7 +603,7 @@ class TestWireContractSmoke:
search_space_id=db_search_space.id,
vault_fingerprint="fp-" + uuid.uuid4().hex,
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)
@ -616,7 +632,7 @@ class TestWireContractSmoke:
mismatched,
],
),
user=db_user,
auth=_auth(db_user),
session=db_session,
)

View file

@ -15,6 +15,7 @@ import pytest
from fastapi import HTTPException
import app.automations.services.automation as automation_mod
from app.auth.context import AuthContext
from app.automations.schemas.api import AutomationCreate, AutomationUpdate
from app.automations.schemas.definition.envelope import (
AutomationDefinition,
@ -45,7 +46,8 @@ class _FakeSession:
def _service(search_space: Any) -> AutomationService:
return AutomationService(
session=_FakeSession(search_space), user=SimpleNamespace(id="u-1")
session=_FakeSession(search_space),
auth=AuthContext.session(SimpleNamespace(id="u-1")),
)

View file

@ -9,6 +9,7 @@ from types import SimpleNamespace
import pytest
from app.auth.context import AuthContext
from app.db import ExternalChatAccount, ExternalChatAccountMode, ExternalChatPlatform
from app.routes import gateway_webhook_routes as routes
@ -333,7 +334,9 @@ async def test_discord_gateway_install_returns_oauth_url(monkeypatch, mocker):
response = await routes.install_discord_gateway(
search_space_id=123,
user=SimpleNamespace(id="00000000-0000-0000-0000-000000000001"),
auth=AuthContext.session(
SimpleNamespace(id="00000000-0000-0000-0000-000000000001")
),
session=mocker.AsyncMock(),
)