feat(story-5.3): add Stripe webhook subscription lifecycle handlers

- Add migration 125: subscription_current_period_end column
- Add PLAN_LIMITS config (free/pro_monthly/pro_yearly token + pages limits)
- Add subscription webhook handlers: created/updated/deleted, invoice payment
- Handle checkout.session.completed for subscription mode separately from PAYG
- Idempotency: subscription_id + status + plan_id + period_end guard
- pages_limit upgraded on activation, gracefully downgraded on cancel
- Token reset on subscription_create and subscription_cycle billing events
- Period_end forward-only guard against out-of-order webhook delivery

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vonic 2026-04-15 00:43:07 +07:00
parent 07a4bc3fc3
commit 20c4f128bb
7 changed files with 381 additions and 27 deletions

View file

@ -0,0 +1,35 @@
"""125_add_subscription_current_period_end
Revision ID: 125
Revises: 124
Create Date: 2026-04-15
Adds subscription_current_period_end column to the user table for
tracking when the current billing period ends (Story 5.3).
Column added:
- subscription_current_period_end (TIMESTAMP with timezone, nullable)
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "125"
down_revision: str | None = "124"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("subscription_current_period_end", sa.TIMESTAMP(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("user", "subscription_current_period_end")

View file

@ -310,6 +310,13 @@ class Config:
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
)
# Subscription plan limits
PLAN_LIMITS: dict[str, dict[str, int]] = {
"free": {"monthly_token_limit": 50_000, "pages_limit": 500},
"pro_monthly": {"monthly_token_limit": 1_000_000, "pages_limit": 5000},
"pro_yearly": {"monthly_token_limit": 1_000_000, "pages_limit": 5000},
}
# Auth
AUTH_TYPE = os.getenv("AUTH_TYPE")
REGISTRATION_ENABLED = os.getenv("REGISTRATION_ENABLED", "TRUE").upper() == "TRUE"

View file

@ -1976,6 +1976,7 @@ if config.AUTH_TYPE == "GOOGLE":
plan_id = Column(String(50), nullable=False, default="free", server_default="free")
stripe_customer_id = Column(String(255), nullable=True, unique=True)
stripe_subscription_id = Column(String(255), nullable=True, unique=True)
subscription_current_period_end = Column(TIMESTAMP(timezone=True), nullable=True)
# User profile from OAuth
display_name = Column(String, nullable=True)
@ -2104,6 +2105,7 @@ else:
plan_id = Column(String(50), nullable=False, default="free", server_default="free")
stripe_customer_id = Column(String(255), nullable=True, unique=True)
stripe_subscription_id = Column(String(255), nullable=True, unique=True)
subscription_current_period_end = Column(TIMESTAMP(timezone=True), nullable=True)
# User profile (can be set manually for non-OAuth users)
display_name = Column(String, nullable=True)

View file

@ -284,6 +284,243 @@ async def _fulfill_completed_purchase(
return StripeWebhookResponse()
# ---------------------------------------------------------------------------
# Subscription event helpers
# ---------------------------------------------------------------------------
async def _get_user_by_stripe_customer_id(
db_session: AsyncSession, customer_id: str
) -> User | None:
"""Fetch the User row for a given Stripe customer ID (with FOR UPDATE lock)."""
return (
(
await db_session.execute(
select(User)
.where(User.stripe_customer_id == customer_id)
.with_for_update(of=User)
)
)
.unique()
.scalar_one_or_none()
)
def _period_end_from_subscription(subscription: Any) -> datetime | None:
"""Extract current_period_end timestamp from a Stripe subscription object."""
ts = getattr(subscription, "current_period_end", None)
if ts is None:
return None
return datetime.fromtimestamp(int(ts), tz=UTC)
async def _handle_subscription_event(
db_session: AsyncSession, subscription: Any
) -> StripeWebhookResponse:
"""Handle customer.subscription.created / updated / deleted.
Idempotency: compares stripe_subscription_id + current_period_end so
duplicate events for the same billing period are no-ops.
"""
customer_id = _normalize_optional_string(getattr(subscription, "customer", None))
subscription_id = _normalize_optional_string(getattr(subscription, "id", None))
sub_status = str(getattr(subscription, "status", "")).lower()
period_end = _period_end_from_subscription(subscription)
# Determine plan from the first subscription item's price ID
plan_id: str = "free"
try:
items = getattr(subscription, "items", None)
if items:
item_data = getattr(items, "data", None) or []
if item_data:
price_id = str(getattr(item_data[0].price, "id", ""))
if price_id == config.STRIPE_PRO_YEARLY_PRICE_ID:
plan_id = "pro_yearly"
elif price_id == config.STRIPE_PRO_MONTHLY_PRICE_ID:
plan_id = "pro_monthly"
else:
logger.warning(
"Subscription %s has unrecognized price ID %s; defaulting to free limits",
subscription_id,
price_id,
)
except Exception: # noqa: BLE001
logger.warning("Could not parse plan from subscription %s", subscription_id)
if not customer_id:
logger.error("Subscription event missing customer ID for subscription %s", subscription_id)
return StripeWebhookResponse()
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping subscription event", customer_id)
return StripeWebhookResponse()
# Map Stripe status → SubscriptionStatus enum
if sub_status == "active":
new_status = SubscriptionStatus.ACTIVE
elif sub_status in {"canceled", "incomplete_expired"}:
new_status = SubscriptionStatus.CANCELED
plan_id = "free"
elif sub_status == "past_due":
new_status = SubscriptionStatus.PAST_DUE
else:
# incomplete, trialing, unpaid → leave current status unchanged
logger.info(
"Ignoring subscription %s with unhandled Stripe status '%s'",
subscription_id,
sub_status,
)
return StripeWebhookResponse()
# Idempotency: skip if nothing meaningful changed
if (
user.stripe_subscription_id == subscription_id
and user.subscription_status == new_status
and user.plan_id == plan_id
and user.subscription_current_period_end == period_end
):
logger.info("Subscription %s already up-to-date; skipping", subscription_id)
return StripeWebhookResponse()
# Update subscription fields
user.stripe_subscription_id = subscription_id
user.subscription_status = new_status
user.plan_id = plan_id
# Guard against out-of-order webhook delivery: only advance period_end forward
if period_end is not None and (
user.subscription_current_period_end is None
or period_end > user.subscription_current_period_end
):
user.subscription_current_period_end = period_end
# Update limits from plan config
limits = config.PLAN_LIMITS.get(plan_id, config.PLAN_LIMITS["free"])
user.monthly_token_limit = limits["monthly_token_limit"]
# Upgrade pages_limit on activation
if new_status == SubscriptionStatus.ACTIVE:
user.pages_limit = max(user.pages_used, limits["pages_limit"])
# Downgrade pages_limit when canceling
if new_status == SubscriptionStatus.CANCELED:
free_limits = config.PLAN_LIMITS["free"]
user.pages_limit = max(user.pages_used, free_limits["pages_limit"])
logger.info(
"Updated subscription for user %s: status=%s plan=%s subscription=%s",
user.id,
new_status,
plan_id,
subscription_id,
)
await db_session.commit()
return StripeWebhookResponse()
async def _handle_invoice_payment_succeeded(
db_session: AsyncSession, invoice: Any
) -> StripeWebhookResponse:
"""Reset tokens_used_this_month and advance token_reset_date on billing renewal."""
customer_id = _normalize_optional_string(getattr(invoice, "customer", None))
billing_reason = str(getattr(invoice, "billing_reason", "")).lower()
if not customer_id:
return StripeWebhookResponse()
# Reset tokens on subscription renewals and initial subscription creation
if billing_reason not in {"subscription_cycle", "subscription_create"}:
logger.info("invoice.payment_succeeded billing_reason=%s; not resetting tokens", billing_reason)
return StripeWebhookResponse()
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping token reset", customer_id)
return StripeWebhookResponse()
user.tokens_used_this_month = 0
user.token_reset_date = datetime.now(UTC).date()
logger.info("Reset tokens_used_this_month for user %s on subscription renewal", user.id)
await db_session.commit()
return StripeWebhookResponse()
async def _handle_invoice_payment_failed(
db_session: AsyncSession, invoice: Any
) -> StripeWebhookResponse:
"""Mark subscription as past_due when a renewal invoice payment fails."""
customer_id = _normalize_optional_string(getattr(invoice, "customer", None))
if not customer_id:
return StripeWebhookResponse()
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping past_due update", customer_id)
return StripeWebhookResponse()
if user.subscription_status == SubscriptionStatus.ACTIVE:
user.subscription_status = SubscriptionStatus.PAST_DUE
logger.info("Set subscription to PAST_DUE for user %s", user.id)
await db_session.commit()
else:
logger.info("invoice.payment_failed for user %s already in status %s; no change", user.id, user.subscription_status)
return StripeWebhookResponse()
async def _activate_subscription_from_checkout(
db_session: AsyncSession, checkout_session: Any
) -> StripeWebhookResponse:
"""Activate subscription when checkout.session.completed fires for mode='subscription'.
The full subscription lifecycle will also be handled by customer.subscription.created,
but we activate immediately here so the user sees Pro access right after checkout.
"""
customer_id = _normalize_optional_string(getattr(checkout_session, "customer", None))
subscription_id = _normalize_optional_string(getattr(checkout_session, "subscription", None))
metadata = _get_metadata(checkout_session)
plan_id_str = metadata.get("plan_id", "")
if not customer_id:
logger.error("Subscription checkout session missing customer ID: %s", getattr(checkout_session, "id", ""))
return StripeWebhookResponse()
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping subscription activation", customer_id)
return StripeWebhookResponse()
# Idempotency: already activated
if user.subscription_status == SubscriptionStatus.ACTIVE and user.stripe_subscription_id == subscription_id:
logger.info("Subscription already active for user %s; skipping activation", user.id)
return StripeWebhookResponse()
plan_id = plan_id_str if plan_id_str in {"pro_monthly", "pro_yearly"} else "pro_monthly"
limits = config.PLAN_LIMITS.get(plan_id, config.PLAN_LIMITS["pro_monthly"])
user.subscription_status = SubscriptionStatus.ACTIVE
user.plan_id = plan_id
user.stripe_subscription_id = subscription_id
user.monthly_token_limit = limits["monthly_token_limit"]
user.pages_limit = max(user.pages_used, limits["pages_limit"])
user.tokens_used_this_month = 0
user.token_reset_date = datetime.now(UTC).date()
# Retrieve subscription object to set period_end (best-effort)
if subscription_id:
try:
stripe_client = get_stripe_client()
sub_obj = stripe_client.v1.subscriptions.retrieve(subscription_id)
user.subscription_current_period_end = _period_end_from_subscription(sub_obj)
except Exception: # noqa: BLE001
logger.warning("Could not retrieve subscription %s for period_end", subscription_id)
logger.info("Activated subscription for user %s: plan=%s subscription=%s", user.id, plan_id, subscription_id)
await db_session.commit()
return StripeWebhookResponse()
@router.post("/create-checkout-session", response_model=CreateCheckoutSessionResponse)
async def create_checkout_session(
body: CreateCheckoutSessionRequest,
@ -484,12 +721,16 @@ async def stripe_webhook(
detail="Invalid Stripe webhook signature.",
) from exc
logger.info("Received Stripe webhook event: %s", event.type)
# --- Checkout session events ---
if event.type in {
"checkout.session.completed",
"checkout.session.async_payment_succeeded",
}:
checkout_session = event.data.object
payment_status = getattr(checkout_session, "payment_status", None)
session_mode = str(getattr(checkout_session, "mode", "payment")).lower()
if event.type == "checkout.session.completed" and payment_status not in {
"paid",
@ -501,6 +742,9 @@ async def stripe_webhook(
)
return StripeWebhookResponse()
if session_mode == "subscription":
return await _activate_subscription_from_checkout(db_session, checkout_session)
return await _fulfill_completed_purchase(db_session, checkout_session)
if event.type in {
@ -508,8 +752,30 @@ async def stripe_webhook(
"checkout.session.expired",
}:
checkout_session = event.data.object
return await _mark_purchase_failed(db_session, str(checkout_session.id))
# Only PAYG purchases have a PagePurchase row; subscription sessions are ignored here.
if str(getattr(checkout_session, "mode", "payment")).lower() != "subscription":
return await _mark_purchase_failed(db_session, str(checkout_session.id))
return StripeWebhookResponse()
# --- Subscription lifecycle events ---
if event.type in {
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
}:
subscription = event.data.object
return await _handle_subscription_event(db_session, subscription)
# --- Invoice events ---
if event.type == "invoice.payment_succeeded":
invoice = event.data.object
return await _handle_invoice_payment_succeeded(db_session, invoice)
if event.type == "invoice.payment_failed":
invoice = event.data.object
return await _handle_invoice_payment_failed(db_session, invoice)
logger.info("Unhandled Stripe event type: %s", event.type)
return StripeWebhookResponse()