diff --git a/surfsense_backend/alembic/versions/168_harden_refresh_token_schema.py b/surfsense_backend/alembic/versions/168_harden_refresh_token_schema.py index fc14c8d73..1e902ea58 100644 --- a/surfsense_backend/alembic/versions/168_harden_refresh_token_schema.py +++ b/surfsense_backend/alembic/versions/168_harden_refresh_token_schema.py @@ -17,35 +17,49 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: - op.add_column( - "refresh_tokens", - sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True), - ) - op.add_column( - "refresh_tokens", - sa.Column("absolute_expiry", sa.TIMESTAMP(timezone=True), nullable=True), + op.execute( + "ALTER TABLE refresh_tokens ADD COLUMN IF NOT EXISTS " + "revoked_at TIMESTAMP WITH TIME ZONE" ) op.execute( - """ - UPDATE refresh_tokens - SET revoked_at = NOW() - WHERE is_revoked = TRUE - """ + "ALTER TABLE refresh_tokens ADD COLUMN IF NOT EXISTS " + "absolute_expiry TIMESTAMP WITH TIME ZONE" ) - op.alter_column( - "refresh_tokens", - "token_hash", - existing_type=sa.String(length=256), - type_=sa.String(length=64), - existing_nullable=False, + + bind = op.get_bind() + is_revoked_exists = bind.execute( + sa.text( + """ + SELECT EXISTS ( + SELECT FROM information_schema.columns + WHERE table_schema = current_schema() + AND table_name = 'refresh_tokens' + AND column_name = 'is_revoked' + ) + """ + ) + ).scalar() + + if is_revoked_exists: + op.execute( + """ + UPDATE refresh_tokens + SET revoked_at = NOW() + WHERE is_revoked = TRUE + AND revoked_at IS NULL + """ + ) + + op.execute( + "ALTER TABLE refresh_tokens ALTER COLUMN token_hash TYPE VARCHAR(64)" ) - op.drop_column("refresh_tokens", "is_revoked") + op.execute("ALTER TABLE refresh_tokens DROP COLUMN IF EXISTS is_revoked") def downgrade() -> None: - op.add_column( - "refresh_tokens", - sa.Column("is_revoked", sa.Boolean(), nullable=False, server_default="false"), + op.execute( + "ALTER TABLE refresh_tokens ADD COLUMN IF NOT EXISTS " + "is_revoked BOOLEAN NOT NULL DEFAULT false" ) op.execute( """ @@ -54,13 +68,9 @@ def downgrade() -> None: WHERE revoked_at IS NOT NULL """ ) - op.alter_column("refresh_tokens", "is_revoked", server_default=None) - op.alter_column( - "refresh_tokens", - "token_hash", - existing_type=sa.String(length=64), - type_=sa.String(length=256), - existing_nullable=False, + op.execute("ALTER TABLE refresh_tokens ALTER COLUMN is_revoked DROP DEFAULT") + op.execute( + "ALTER TABLE refresh_tokens ALTER COLUMN token_hash TYPE VARCHAR(256)" ) - op.drop_column("refresh_tokens", "absolute_expiry") - op.drop_column("refresh_tokens", "revoked_at") + op.execute("ALTER TABLE refresh_tokens DROP COLUMN IF EXISTS absolute_expiry") + op.execute("ALTER TABLE refresh_tokens DROP COLUMN IF EXISTS revoked_at")