Epic 5 Complete: Billing, Subscriptions, and Admin Features

Resolve all 5 deferred items from Epic 5 adversarial code review:
- Migration 124: Add CASCADE to subscriptionstatus enum drop (prevent orphaned references)
- Stripe rate limiting: In-memory per-user limiter (20 calls/60s) on verify-checkout-session
- Subscription request cooldown: 24h cooldown before resubmitting rejected requests
- Token reset date: Initialize on first subscription activation
- Checkout URL validation: Confirmed HTTPS-only (Stripe always returns HTTPS)

Implement Story 5.4 (Usage Tracking & Rate Limit Enforcement):
- Page quota pre-check at HTTP upload layer
- Extend UserRead schema with token quota fields
- Frontend 402 error handling in document upload
- Quota indicator in dashboard sidebar

Story 5.5 (Admin Seed & Approval Flow):
- Seed admin user migration with default credentials warning
- Subscription approval/rejection routes with admin guard
- 24h rejection cooldown enforcement

Story 5.6 (Admin-Only Model Config):
- Global model config visible across all search spaces
- Per-search-space model configs with user access control
- Superuser CRUD for global configs

Additional fixes from code review:
- PageLimitService: PAST_DUE subscriptions enforce free-tier limits
- TokenQuotaService: PAST_DUE subscriptions enforce free-tier limits
- Config routes: Fixed user_id.is_(None) filter on mutation endpoints
- Stripe webhook: Added guard against silent plan downgrade on unrecognized price_id

All changes formatted with Ruff (Python) and Biome (TypeScript).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Vonic 2026-04-15 03:54:45 +07:00
parent 20c4f128bb
commit 4eb6ed18d6
41 changed files with 1771 additions and 318 deletions

View file

@ -24,6 +24,7 @@ from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "124"
@ -33,17 +34,35 @@ depends_on: str | Sequence[str] | None = None
# Create the enum type so SQLAlchemy's create_type=False works at runtime
subscriptionstatus_enum = sa.Enum(
"free", "active", "canceled", "past_due",
"free",
"active",
"canceled",
"past_due",
name="subscriptionstatus",
)
def upgrade() -> None:
# Create the PostgreSQL enum type first
subscriptionstatus_enum.create(op.get_bind(), checkfirst=True)
# Drop any pre-existing subscriptionstatus enum (e.g. uppercase version created by
# SQLAlchemy's create_all() during early development) so we can create it with
# the correct lowercase values. Safe to drop here because no column uses it yet.
conn = op.get_bind()
conn.execute(sa.text("DROP TYPE IF EXISTS subscriptionstatus CASCADE"))
# Create the PostgreSQL enum type with lowercase values
subscriptionstatus_enum.create(conn, checkfirst=False)
op.add_column("user", sa.Column("monthly_token_limit", sa.Integer(), nullable=False, server_default="100000"))
op.add_column("user", sa.Column("tokens_used_this_month", sa.Integer(), nullable=False, server_default="0"))
op.add_column(
"user",
sa.Column(
"monthly_token_limit", sa.Integer(), nullable=False, server_default="100000"
),
)
op.add_column(
"user",
sa.Column(
"tokens_used_this_month", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column("user", sa.Column("token_reset_date", sa.Date(), nullable=True))
op.add_column(
"user",
@ -54,12 +73,23 @@ def upgrade() -> None:
server_default="free",
),
)
op.add_column("user", sa.Column("plan_id", sa.String(50), nullable=False, server_default="free"))
op.add_column("user", sa.Column("stripe_customer_id", sa.String(255), nullable=True))
op.add_column("user", sa.Column("stripe_subscription_id", sa.String(255), nullable=True))
op.add_column(
"user",
sa.Column("plan_id", sa.String(50), nullable=False, server_default="free"),
)
op.add_column(
"user", sa.Column("stripe_customer_id", sa.String(255), nullable=True)
)
op.add_column(
"user", sa.Column("stripe_subscription_id", sa.String(255), nullable=True)
)
op.create_unique_constraint("uq_user_stripe_customer_id", "user", ["stripe_customer_id"])
op.create_unique_constraint("uq_user_stripe_subscription_id", "user", ["stripe_subscription_id"])
op.create_unique_constraint(
"uq_user_stripe_customer_id", "user", ["stripe_customer_id"]
)
op.create_unique_constraint(
"uq_user_stripe_subscription_id", "user", ["stripe_subscription_id"]
)
def downgrade() -> None:

View file

@ -0,0 +1,207 @@
"""126_seed_admin_user
Revision ID: 126
Revises: 125
Create Date: 2026-04-15
Seeds one admin user on fresh installs (no-op if any user already exists).
Credentials are overridable via env vars:
ADMIN_EMAIL (default: admin@surfsense.local)
ADMIN_PASSWORD (default: Admin@SurfSense1)
Admin is created with:
- is_superuser = TRUE, is_active = TRUE, is_verified = TRUE
- subscription_status = 'active', plan_id = 'pro_yearly'
- monthly_token_limit = 1_000_000, pages_limit = 5000
- A default search space, roles, membership, and prompt defaults
"""
from __future__ import annotations
import os
import uuid
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "126"
down_revision: str | None = "125"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def _hash_password(password: str) -> str:
"""Hash password using argon2-cffi (installed as a fastapi-users dependency)."""
from argon2 import PasswordHasher
ph = PasswordHasher()
return ph.hash(password)
def upgrade() -> None:
conn = op.get_bind()
# Only seed when the database is empty
result = conn.execute(sa.text('SELECT 1 FROM "user" LIMIT 1'))
if result.fetchone() is not None:
return # Users already exist — skip seed
admin_email = os.environ.get("ADMIN_EMAIL", "admin@surfsense.local")
admin_password = os.environ.get("ADMIN_PASSWORD", "Admin@SurfSense1")
if not os.environ.get("ADMIN_PASSWORD"):
print(
"\n⚠️ WARNING: ADMIN_PASSWORD env var not set. "
"Using default password 'Admin@SurfSense1'. "
"Change this immediately after first login!\n"
)
hashed_pw = _hash_password(admin_password)
admin_id = str(uuid.uuid4())
# 1. Insert admin user
conn.execute(
sa.text(
"""
INSERT INTO "user" (
id, email, hashed_password,
is_active, is_superuser, is_verified,
subscription_status, plan_id,
monthly_token_limit, pages_limit, pages_used,
tokens_used_this_month
) VALUES (
:id, :email, :hashed_password,
TRUE, TRUE, TRUE,
'active', 'pro_yearly',
1000000, 5000, 0,
0
)
"""
),
{
"id": admin_id,
"email": admin_email,
"hashed_password": hashed_pw,
},
)
# 2. Insert default search space for admin (only required columns; defaults handle the rest)
search_space_result = conn.execute(
sa.text(
"""
INSERT INTO searchspaces (name, description, citations_enabled, user_id, created_at)
VALUES ('My Search Space', 'Your personal search space', TRUE, :user_id, now())
RETURNING id
"""
),
{"user_id": admin_id},
)
search_space_id = search_space_result.fetchone()[0]
# 3. Insert default roles for the search space
owner_role_result = conn.execute(
sa.text(
"""
INSERT INTO search_space_roles
(name, description, permissions, is_default, is_system_role, search_space_id, created_at)
VALUES (
'Owner', 'Full access to all search space resources and settings',
ARRAY['*'], FALSE, TRUE, :ss_id, now()
)
RETURNING id
"""
),
{"ss_id": search_space_id},
)
owner_role_id = owner_role_result.fetchone()[0]
conn.execute(
sa.text(
"""
INSERT INTO search_space_roles
(name, description, permissions, is_default, is_system_role, search_space_id, created_at)
VALUES
(
'Editor',
'Can create and update content (no delete, role management, or settings access)',
ARRAY[
'documents:create','documents:read','documents:update',
'chats:create','chats:read','chats:update',
'comments:create','comments:read',
'llm_configs:create','llm_configs:read','llm_configs:update',
'podcasts:create','podcasts:read','podcasts:update',
'video_presentations:create','video_presentations:read','video_presentations:update',
'image_generations:create','image_generations:read',
'vision_configs:create','vision_configs:read',
'connectors:create','connectors:read','connectors:update',
'logs:read', 'members:invite'
],
TRUE, TRUE, :ss_id, now()
),
(
'Viewer', 'Read-only access to search space resources',
ARRAY[
'documents:read','chats:read','comments:read',
'llm_configs:read','podcasts:read','video_presentations:read',
'image_generations:read','vision_configs:read','connectors:read','logs:read'
],
FALSE, TRUE, :ss_id, now()
)
"""
),
{"ss_id": search_space_id},
)
# 4. Insert owner membership
conn.execute(
sa.text(
"""
INSERT INTO search_space_memberships
(user_id, search_space_id, role_id, is_owner, joined_at, created_at)
VALUES (:user_id, :ss_id, :role_id, TRUE, now(), now())
"""
),
{"user_id": admin_id, "ss_id": search_space_id, "role_id": owner_role_id},
)
# 5. Insert default prompts (same as migration 114 but just for admin)
conn.execute(
sa.text(
"""
INSERT INTO prompts
(user_id, default_prompt_slug, name, prompt, mode, version, is_public, created_at)
VALUES
(:uid, 'fix-grammar', 'Fix grammar',
'Fix the grammar and spelling in the following text. Return only the corrected text, nothing else.\n\n{selection}',
'transform'::prompt_mode, 1, false, now()),
(:uid, 'make-shorter', 'Make shorter',
'Make the following text more concise while preserving its meaning. Return only the shortened text, nothing else.\n\n{selection}',
'transform'::prompt_mode, 1, false, now()),
(:uid, 'translate', 'Translate',
'Translate the following text to English. If it is already in English, translate it to French. Return only the translation, nothing else.\n\n{selection}',
'transform'::prompt_mode, 1, false, now()),
(:uid, 'rewrite', 'Rewrite',
'Rewrite the following text to improve clarity and readability. Return only the rewritten text, nothing else.\n\n{selection}',
'transform'::prompt_mode, 1, false, now()),
(:uid, 'summarize', 'Summarize',
'Summarize the following text concisely. Return only the summary, nothing else.\n\n{selection}',
'transform'::prompt_mode, 1, false, now()),
(:uid, 'explain', 'Explain',
'Explain the following text in simple terms:\n\n{selection}',
'explore'::prompt_mode, 1, false, now()),
(:uid, 'ask-knowledge-base', 'Ask my knowledge base',
'Search my knowledge base for information related to:\n\n{selection}',
'explore'::prompt_mode, 1, false, now()),
(:uid, 'look-up-web', 'Look up on the web',
'Search the web for information about:\n\n{selection}',
'explore'::prompt_mode, 1, false, now())
ON CONFLICT (user_id, default_prompt_slug) DO NOTHING
"""
),
{"uid": admin_id},
)
def downgrade() -> None:
# Intentional no-op: never delete users on downgrade
pass

View file

@ -0,0 +1,59 @@
"""127_add_subscription_requests_table
Revision ID: 127
Revises: 126
Create Date: 2026-04-15
Adds the subscription_requests table for admin-approval flow when Stripe
is not configured. Users submit a subscription request; superusers can
approve/reject it from the admin panel.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "127"
down_revision: str | None = "126"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
conn = op.get_bind()
# Drop any pre-existing enum (e.g. uppercase version from old create_all())
conn.execute(sa.text("DROP TYPE IF EXISTS subscriptionrequeststatus"))
conn.execute(
sa.text(
"CREATE TYPE subscriptionrequeststatus AS ENUM ('pending', 'approved', 'rejected')"
)
)
conn.execute(
sa.text(
"""
CREATE TABLE subscription_requests (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
plan_id VARCHAR(50) NOT NULL,
status subscriptionrequeststatus NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
approved_at TIMESTAMPTZ,
approved_by UUID REFERENCES "user"(id)
)
"""
)
)
conn.execute(
sa.text(
"CREATE INDEX ix_subscription_requests_user_id ON subscription_requests (user_id)"
)
)
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS subscription_requests"))
conn.execute(sa.text("DROP TYPE IF EXISTS subscriptionrequeststatus"))

View file

@ -0,0 +1,46 @@
"""128_make_model_config_user_id_nullable
Revision ID: 128
Revises: 127
Create Date: 2026-04-15
Makes user_id nullable on the three model-config tables so that admin-created
(superuser-owned) configurations have user_id = NULL and are visible to all
members of the search space.
Tables affected:
- new_llm_configs
- image_generation_configs
- vision_llm_configs
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "128"
down_revision: str | None = "127"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.alter_column("new_llm_configs", "user_id", existing_type=sa.UUID(), nullable=True)
op.alter_column("image_generation_configs", "user_id", existing_type=sa.UUID(), nullable=True)
op.alter_column("vision_llm_configs", "user_id", existing_type=sa.UUID(), nullable=True)
def downgrade() -> None:
conn = op.get_bind()
# Null out orphaned rows before re-adding NOT NULL (safety guard)
for table in ("new_llm_configs", "image_generation_configs", "vision_llm_configs"):
# If any rows have user_id=NULL we cannot restore NOT NULL — delete them
conn.execute(sa.text(f'DELETE FROM "{table}" WHERE user_id IS NULL'))
op.alter_column("new_llm_configs", "user_id", existing_type=sa.UUID(), nullable=False)
op.alter_column("image_generation_configs", "user_id", existing_type=sa.UUID(), nullable=False)
op.alter_column("vision_llm_configs", "user_id", existing_type=sa.UUID(), nullable=False)

View file

@ -0,0 +1,69 @@
"""129_make_model_config_search_space_id_nullable
Revision ID: 129
Revises: 128
Create Date: 2026-04-15
Makes search_space_id nullable on the three model-config tables so that
admin-created (superuser-owned) configurations have search_space_id = NULL
and are visible to ALL users across ALL search spaces.
Tables affected:
- new_llm_configs
- image_generation_configs
- vision_llm_configs
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "129"
down_revision: str | None = "128"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.alter_column(
"new_llm_configs", "search_space_id", existing_type=sa.Integer(), nullable=True
)
op.alter_column(
"image_generation_configs",
"search_space_id",
existing_type=sa.Integer(),
nullable=True,
)
op.alter_column(
"vision_llm_configs",
"search_space_id",
existing_type=sa.Integer(),
nullable=True,
)
def downgrade() -> None:
conn = op.get_bind()
# Delete global configs (search_space_id IS NULL) before restoring NOT NULL
for table in ("new_llm_configs", "image_generation_configs", "vision_llm_configs"):
conn.execute(sa.text(f'DELETE FROM "{table}" WHERE search_space_id IS NULL'))
op.alter_column(
"new_llm_configs", "search_space_id", existing_type=sa.Integer(), nullable=False
)
op.alter_column(
"image_generation_configs",
"search_space_id",
existing_type=sa.Integer(),
nullable=False,
)
op.alter_column(
"vision_llm_configs",
"search_space_id",
existing_type=sa.Integer(),
nullable=False,
)

View file

@ -1197,15 +1197,15 @@ class ImageGenerationConfig(BaseModel, TimestampMixin):
# Relationships
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
)
search_space = relationship(
"SearchSpace", back_populates="image_generation_configs"
)
# User who created this config
# User who created this config (NULL for admin-created global configs)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
user = relationship("User", back_populates="image_generation_configs")
@ -1227,12 +1227,13 @@ class VisionLLMConfig(BaseModel, TimestampMixin):
litellm_params = Column(JSON, nullable=True, default={})
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
)
search_space = relationship("SearchSpace", back_populates="vision_llm_configs")
# User who created this config (NULL for admin-created global configs)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
user = relationship("User", back_populates="vision_llm_configs")
@ -1535,13 +1536,13 @@ class NewLLMConfig(BaseModel, TimestampMixin):
# === Relationships ===
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
)
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
# User who created this config
# User who created this config (NULL for admin-created global configs)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
user = relationship("User", back_populates="new_llm_configs")
@ -1683,6 +1684,56 @@ class PagePurchase(Base, TimestampMixin):
user = relationship("User", back_populates="page_purchases")
class SubscriptionRequestStatus(StrEnum):
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"
class SubscriptionRequest(Base):
"""Tracks subscription upgrade requests when Stripe is not configured (admin-approval flow)."""
__tablename__ = "subscription_requests"
__allow_unmapped__ = True
id = Column(
UUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
user_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
plan_id = Column(String(50), nullable=False)
status = Column(
SQLAlchemyEnum(
SubscriptionRequestStatus,
name="subscriptionrequeststatus",
create_type=False,
values_callable=lambda x: [e.value for e in x],
),
nullable=False,
default=SubscriptionRequestStatus.PENDING,
server_default="pending",
)
created_at = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text("now()"),
)
approved_at = Column(TIMESTAMP(timezone=True), nullable=True)
approved_by = Column(
UUID(as_uuid=True),
ForeignKey("user.id"),
nullable=True,
)
user = relationship("User", foreign_keys=[user_id], back_populates="subscription_requests")
class SearchSpaceRole(BaseModel, TimestampMixin):
"""
Custom roles that can be defined per search space.
@ -1953,6 +2004,12 @@ if config.AUTH_TYPE == "GOOGLE":
back_populates="user",
cascade="all, delete-orphan",
)
subscription_requests = relationship(
"SubscriptionRequest",
foreign_keys="SubscriptionRequest.user_id",
back_populates="user",
cascade="all, delete-orphan",
)
# Page usage tracking for ETL services
pages_limit = Column(
@ -1968,7 +2025,7 @@ if config.AUTH_TYPE == "GOOGLE":
tokens_used_this_month = Column(Integer, nullable=False, default=0, server_default="0")
token_reset_date = Column(Date, nullable=True)
subscription_status = Column(
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True),
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True, values_callable=lambda x: [e.value for e in x]),
nullable=False,
default=SubscriptionStatus.FREE,
server_default="free",
@ -2082,6 +2139,12 @@ else:
back_populates="user",
cascade="all, delete-orphan",
)
subscription_requests = relationship(
"SubscriptionRequest",
foreign_keys="SubscriptionRequest.user_id",
back_populates="user",
cascade="all, delete-orphan",
)
# Page usage tracking for ETL services
pages_limit = Column(
@ -2097,7 +2160,7 @@ else:
tokens_used_this_month = Column(Integer, nullable=False, default=0, server_default="0")
token_reset_date = Column(Date, nullable=True)
subscription_status = Column(
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True),
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True, values_callable=lambda x: [e.value for e in x]),
nullable=False,
default=SubscriptionStatus.FREE,
server_default="free",

View file

@ -48,6 +48,7 @@ from .sandbox_routes import router as sandbox_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .search_spaces_routes import router as search_spaces_router
from .slack_add_connector_route import router as slack_add_connector_router
from .admin_routes import router as admin_router
from .stripe_routes import router as stripe_router
from .surfsense_docs_routes import router as surfsense_docs_router
from .teams_add_connector_route import router as teams_add_connector_router
@ -100,6 +101,7 @@ router.include_router(notifications_router) # Notifications with Zero sync
router.include_router(composio_router) # Composio OAuth and toolkit management
router.include_router(public_chat_router) # Public chat sharing and cloning
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
router.include_router(admin_router) # Superuser admin operations
router.include_router(stripe_router) # Stripe checkout for additional page packs
router.include_router(youtube_router) # YouTube playlist resolution
router.include_router(prompts_router)

View file

@ -0,0 +1,212 @@
"""Admin routes — superuser-only operations."""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
SubscriptionRequest,
SubscriptionRequestStatus,
SubscriptionStatus,
User,
get_async_session,
)
from app.users import current_superuser
router = APIRouter(prefix="/admin", tags=["admin"])
# ---------------------------------------------------------------------------
# Response schemas
# ---------------------------------------------------------------------------
class SubscriptionRequestItem(BaseModel):
id: uuid.UUID
user_id: uuid.UUID
user_email: str
plan_id: str
status: str
created_at: datetime
approved_at: datetime | None = None
approved_by: uuid.UUID | None = None
model_config = {"from_attributes": True}
# ---------------------------------------------------------------------------
# List pending subscription requests
# ---------------------------------------------------------------------------
@router.get(
"/subscription-requests",
response_model=list[SubscriptionRequestItem],
)
async def list_subscription_requests(
admin: User = Depends(current_superuser),
db_session: AsyncSession = Depends(get_async_session),
) -> list[SubscriptionRequestItem]:
"""Return all pending subscription requests."""
result = await db_session.execute(
select(SubscriptionRequest)
.where(SubscriptionRequest.status == SubscriptionRequestStatus.PENDING)
.order_by(SubscriptionRequest.created_at.asc())
)
requests = result.scalars().all()
# Collect user IDs and batch-load to avoid N+1
user_ids = [req.user_id for req in requests]
email_map: dict[uuid.UUID, str] = {}
if user_ids:
user_rows = await db_session.execute(select(User).where(User.id.in_(user_ids)))
for u in user_rows.scalars():
email_map[u.id] = u.email
items: list[SubscriptionRequestItem] = [
SubscriptionRequestItem(
id=req.id,
user_id=req.user_id,
user_email=email_map.get(req.user_id, "<deleted>"),
plan_id=req.plan_id,
status=req.status.value,
created_at=req.created_at,
approved_at=req.approved_at,
approved_by=req.approved_by,
)
for req in requests
]
return items
# ---------------------------------------------------------------------------
# Approve a subscription request
# ---------------------------------------------------------------------------
@router.post(
"/subscription-requests/{request_id}/approve",
response_model=SubscriptionRequestItem,
)
async def approve_subscription_request(
request_id: uuid.UUID,
admin: User = Depends(current_superuser),
db_session: AsyncSession = Depends(get_async_session),
) -> SubscriptionRequestItem:
"""Approve a pending subscription request and activate the user's subscription."""
result = await db_session.execute(
select(SubscriptionRequest)
.where(SubscriptionRequest.id == request_id)
.with_for_update()
)
req = result.scalar_one_or_none()
if req is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Subscription request not found.",
)
if req.status != SubscriptionRequestStatus.PENDING:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Request is already {req.status.value}.",
)
user_result = await db_session.execute(
select(User).where(User.id == req.user_id).with_for_update()
)
req_user = user_result.scalar_one_or_none()
if req_user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found."
)
# Activate subscription
plan_limits = config.PLAN_LIMITS.get(req.plan_id, config.PLAN_LIMITS["free"])
req_user.subscription_status = SubscriptionStatus.ACTIVE
req_user.plan_id = req.plan_id
req_user.monthly_token_limit = plan_limits["monthly_token_limit"]
req_user.pages_limit = max(req_user.pages_used or 0, plan_limits["pages_limit"])
req_user.tokens_used_this_month = 0
req_user.token_reset_date = datetime.now(UTC).date()
# Mark request approved
now = datetime.now(UTC)
req.status = SubscriptionRequestStatus.APPROVED
req.approved_at = now
req.approved_by = admin.id
await db_session.commit()
await db_session.refresh(req)
user_result2 = await db_session.execute(select(User).where(User.id == req.user_id))
req_user2 = user_result2.scalar_one_or_none()
email = req_user2.email if req_user2 else "<deleted>"
return SubscriptionRequestItem(
id=req.id,
user_id=req.user_id,
user_email=email,
plan_id=req.plan_id,
status=req.status.value,
created_at=req.created_at,
approved_at=req.approved_at,
approved_by=req.approved_by,
)
# ---------------------------------------------------------------------------
# Reject a subscription request
# ---------------------------------------------------------------------------
@router.post(
"/subscription-requests/{request_id}/reject",
response_model=SubscriptionRequestItem,
)
async def reject_subscription_request(
request_id: uuid.UUID,
admin: User = Depends(current_superuser),
db_session: AsyncSession = Depends(get_async_session),
) -> SubscriptionRequestItem:
"""Reject a pending subscription request."""
result = await db_session.execute(
select(SubscriptionRequest).where(SubscriptionRequest.id == request_id)
)
req = result.scalar_one_or_none()
if req is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Subscription request not found.",
)
if req.status != SubscriptionRequestStatus.PENDING:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Request is already {req.status.value}.",
)
req.status = SubscriptionRequestStatus.REJECTED
await db_session.commit()
await db_session.refresh(req)
user_result = await db_session.execute(select(User).where(User.id == req.user_id))
req_user = user_result.scalar_one_or_none()
email = req_user.email if req_user else "<deleted>"
return SubscriptionRequestItem(
id=req.id,
user_id=req.user_id,
user_email=email,
plan_id=req.plan_id,
status=req.status.value,
created_at=req.created_at,
approved_at=req.approved_at,
approved_by=req.approved_by,
)

View file

@ -73,6 +73,24 @@ async def create_documents(
"You don't have permission to create documents in this search space",
)
# Page quota pre-check for connector documents
from app.services.page_limit_service import (
PageLimitExceededError,
PageLimitService,
)
estimated_pages = len(request.content) # 1 page per document/URL
try:
page_service = PageLimitService(session)
await page_service.check_page_limit(str(user.id), estimated_pages)
except PageLimitExceededError as e:
raise HTTPException(
status_code=402,
detail=f"Page quota exceeded ({e.pages_used}/{e.pages_limit}). "
f"This request requires ~{estimated_pages} pages. "
f"Upgrade your plan for more pages.",
) from e
if request.document_type == DocumentType.EXTENSION:
from app.tasks.celery_tasks.document_tasks import (
process_extension_document_task,
@ -169,6 +187,30 @@ async def create_documents_file_upload(
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
# Page quota pre-check
from app.services.page_limit_service import (
PageLimitExceededError,
PageLimitService,
)
total_estimated_pages = sum(
PageLimitService.estimate_pages_from_metadata(
file.filename or "", file.size or 0
)
for file in files
)
try:
page_service = PageLimitService(session)
await page_service.check_page_limit(str(user.id), total_estimated_pages)
except PageLimitExceededError as e:
raise HTTPException(
status_code=402,
detail=f"Page quota exceeded ({e.pages_used}/{e.pages_limit}). "
f"This upload requires ~{total_estimated_pages} pages. "
f"Upgrade your plan for more pages.",
) from e
# ===== Read all files concurrently to avoid blocking the event loop =====
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
"""Read upload content and write to temp file off the event loop."""

View file

@ -30,6 +30,7 @@ from app.db import (
from app.schemas import (
GlobalImageGenConfigRead,
ImageGenerationConfigCreate,
ImageGenerationConfigPublic,
ImageGenerationConfigRead,
ImageGenerationConfigUpdate,
ImageGenerationCreate,
@ -41,7 +42,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.users import current_active_user
from app.users import current_active_user, current_superuser
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -261,19 +262,11 @@ async def get_global_image_gen_configs(
async def create_image_gen_config(
config_data: ImageGenerationConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""Create a new image generation config for a search space."""
"""Create a new image generation config for a search space. Superuser only."""
try:
await check_permission(
session,
user,
config_data.search_space_id,
Permission.IMAGE_GENERATIONS_CREATE.value,
"You don't have permission to create image generation configs in this search space",
)
db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id)
db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=None)
session.add(db_config)
await session.commit()
await session.refresh(db_config)
@ -289,7 +282,9 @@ async def create_image_gen_config(
) from e
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
@router.get(
"/image-generation-configs", response_model=list[ImageGenerationConfigPublic]
)
async def list_image_gen_configs(
search_space_id: int,
skip: int = 0,
@ -297,7 +292,7 @@ async def list_image_gen_configs(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""List image generation configs for a search space."""
"""List image generation configs for a search space (includes global admin configs)."""
try:
await check_permission(
session,
@ -309,7 +304,10 @@ async def list_image_gen_configs(
result = await session.execute(
select(ImageGenerationConfig)
.filter(ImageGenerationConfig.search_space_id == search_space_id)
.filter(
(ImageGenerationConfig.search_space_id == search_space_id)
| (ImageGenerationConfig.search_space_id == None) # noqa: E711
)
.order_by(ImageGenerationConfig.created_at.desc())
.offset(skip)
.limit(limit)
@ -367,25 +365,20 @@ async def update_image_gen_config(
config_id: int,
update_data: ImageGenerationConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""Update an existing image generation config."""
"""Update an existing image generation config. Superuser only."""
try:
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
select(ImageGenerationConfig).filter(
ImageGenerationConfig.id == config_id,
ImageGenerationConfig.user_id.is_(None),
)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.IMAGE_GENERATIONS_CREATE.value,
"You don't have permission to update image generation configs in this search space",
)
for key, value in update_data.model_dump(exclude_unset=True).items():
setattr(db_config, key, value)
@ -407,25 +400,20 @@ async def update_image_gen_config(
async def delete_image_gen_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""Delete an image generation config."""
"""Delete an image generation config. Superuser only."""
try:
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
select(ImageGenerationConfig).filter(
ImageGenerationConfig.id == config_id,
ImageGenerationConfig.user_id.is_(None),
)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.IMAGE_GENERATIONS_DELETE.value,
"You don't have permission to delete image generation configs in this search space",
)
await session.delete(db_config)
await session.commit()
return {

View file

@ -25,11 +25,12 @@ from app.schemas import (
DefaultSystemInstructionsResponse,
GlobalNewLLMConfigRead,
NewLLMConfigCreate,
NewLLMConfigPublic,
NewLLMConfigRead,
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
from app.users import current_active_user
from app.users import current_active_user, current_superuser
from app.utils.rbac import check_permission
router = APIRouter()
@ -117,22 +118,13 @@ async def get_global_new_llm_configs(
async def create_new_llm_config(
config_data: NewLLMConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""
Create a new NewLLMConfig for a search space.
Requires LLM_CONFIGS_CREATE permission.
Superuser only configs are shared with all search space members.
"""
try:
# Verify user has permission
await check_permission(
session,
user,
config_data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create LLM configurations in this search space",
)
# Validate the LLM configuration by making a test API call
is_valid, error_message = await validate_llm_config(
provider=config_data.provider.value,
@ -149,8 +141,8 @@ async def create_new_llm_config(
detail=f"Invalid LLM configuration: {error_message}",
)
# Create the config with user association
db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id)
# Create the config as admin-owned (user_id=None means shared with all space members)
db_config = NewLLMConfig(**config_data.model_dump(), user_id=None)
session.add(db_config)
await session.commit()
await session.refresh(db_config)
@ -167,7 +159,7 @@ async def create_new_llm_config(
) from e
@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead])
@router.get("/new-llm-configs", response_model=list[NewLLMConfigPublic])
async def list_new_llm_configs(
search_space_id: int,
skip: int = 0,
@ -176,11 +168,11 @@ async def list_new_llm_configs(
user: User = Depends(current_active_user),
):
"""
Get all NewLLMConfigs for a search space.
Get all NewLLMConfigs for a search space (includes global admin configs).
Requires LLM_CONFIGS_READ permission.
"""
try:
# Verify user has permission
# Verify user has permission for their space
await check_permission(
session,
user,
@ -191,7 +183,10 @@ async def list_new_llm_configs(
result = await session.execute(
select(NewLLMConfig)
.filter(NewLLMConfig.search_space_id == search_space_id)
.filter(
(NewLLMConfig.search_space_id == search_space_id)
| (NewLLMConfig.search_space_id == None) # noqa: E711
)
.order_by(NewLLMConfig.created_at.desc())
.offset(skip)
.limit(limit)
@ -268,30 +263,23 @@ async def update_new_llm_config(
config_id: int,
update_data: NewLLMConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""
Update an existing NewLLMConfig.
Requires LLM_CONFIGS_UPDATE permission.
Superuser only.
"""
try:
result = await session.execute(
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
select(NewLLMConfig).filter(
NewLLMConfig.id == config_id, NewLLMConfig.user_id.is_(None)
)
)
config = result.scalars().first()
if not config:
raise HTTPException(status_code=404, detail="Configuration not found")
# Verify user has permission
await check_permission(
session,
user,
config.search_space_id,
Permission.LLM_CONFIGS_UPDATE.value,
"You don't have permission to update LLM configurations in this search space",
)
update_dict = update_data.model_dump(exclude_unset=True)
# If updating LLM settings, validate them
@ -360,30 +348,23 @@ async def update_new_llm_config(
async def delete_new_llm_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""
Delete a NewLLMConfig.
Requires LLM_CONFIGS_DELETE permission.
Superuser only.
"""
try:
result = await session.execute(
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
select(NewLLMConfig).filter(
NewLLMConfig.id == config_id, NewLLMConfig.user_id.is_(None)
)
)
config = result.scalars().first()
if not config:
raise HTTPException(status_code=404, detail="Configuration not found")
# Verify user has permission
await check_permission(
session,
user,
config.search_space_id,
Permission.LLM_CONFIGS_DELETE.value,
"You don't have permission to delete LLM configurations in this search space",
)
await session.delete(config)
await session.commit()

View file

@ -4,7 +4,8 @@ from __future__ import annotations
import logging
import uuid
from datetime import UTC, datetime
from collections import defaultdict
from datetime import UTC, datetime, timedelta
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status
@ -13,7 +14,15 @@ from sqlalchemy.ext.asyncio import AsyncSession
from stripe import SignatureVerificationError, StripeClient, StripeError
from app.config import config
from app.db import PagePurchase, PagePurchaseStatus, SubscriptionStatus, User, get_async_session
from app.db import (
PagePurchase,
PagePurchaseStatus,
SubscriptionRequest,
SubscriptionRequestStatus,
SubscriptionStatus,
User,
get_async_session,
)
from app.schemas.stripe import (
CreateCheckoutSessionRequest,
CreateCheckoutSessionResponse,
@ -30,6 +39,28 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/stripe", tags=["stripe"])
# ---------------------------------------------------------------------------
# Simple in-memory rate limiter for verify-checkout-session (20 calls/60 s)
# Not persistent across workers — acceptable for the low-risk, low-volume
# nature of this endpoint.
# ---------------------------------------------------------------------------
_VERIFY_SESSION_WINDOW_SECS = 60
_VERIFY_SESSION_MAX_CALLS = 20
_verify_session_calls: dict[str, list[float]] = defaultdict(list)
def _check_verify_session_rate_limit(user_id: str) -> None:
now = datetime.now(UTC).timestamp()
cutoff = now - _VERIFY_SESSION_WINDOW_SECS
calls = [t for t in _verify_session_calls[user_id] if t > cutoff]
if len(calls) >= _VERIFY_SESSION_MAX_CALLS:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many requests. Try again later.",
)
calls.append(now)
_verify_session_calls[user_id] = calls
def get_stripe_client() -> StripeClient:
"""Return a configured Stripe client or raise if Stripe is disabled."""
@ -145,7 +176,10 @@ async def _get_or_create_stripe_customer(
try:
customer = stripe_client.v1.customers.create(
params={"email": locked_user.email, "metadata": {"user_id": str(locked_user.id)}}
params={
"email": locked_user.email,
"metadata": {"user_id": str(locked_user.id)},
}
)
except StripeError as exc:
logger.exception("Failed to create Stripe customer for user %s", locked_user.id)
@ -288,6 +322,7 @@ async def _fulfill_completed_purchase(
# Subscription event helpers
# ---------------------------------------------------------------------------
async def _get_user_by_stripe_customer_id(
db_session: AsyncSession, customer_id: str
) -> User | None:
@ -344,16 +379,34 @@ async def _handle_subscription_event(
subscription_id,
price_id,
)
except Exception: # noqa: BLE001
except Exception:
logger.warning("Could not parse plan from subscription %s", subscription_id)
if not customer_id:
logger.error("Subscription event missing customer ID for subscription %s", subscription_id)
logger.error(
"Subscription event missing customer ID for subscription %s",
subscription_id,
)
return StripeWebhookResponse()
# Safety: never silently downgrade an active subscription to "free" due to
# an unrecognized price ID. Return early without modifying the user.
if (
plan_id == "free"
and str(getattr(subscription, "status", "")).lower() == "active"
):
logger.error(
"Subscription %s is active but price ID is unrecognized — skipping update to avoid downgrade",
subscription_id,
)
return StripeWebhookResponse()
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping subscription event", customer_id)
logger.warning(
"No user found for Stripe customer %s; skipping subscription event",
customer_id,
)
return StripeWebhookResponse()
# Map Stripe status → SubscriptionStatus enum
@ -398,9 +451,11 @@ async def _handle_subscription_event(
limits = config.PLAN_LIMITS.get(plan_id, config.PLAN_LIMITS["free"])
user.monthly_token_limit = limits["monthly_token_limit"]
# Upgrade pages_limit on activation
# Upgrade pages_limit on activation; reset token counter date
if new_status == SubscriptionStatus.ACTIVE:
user.pages_limit = max(user.pages_used, limits["pages_limit"])
if user.token_reset_date is None:
user.token_reset_date = datetime.now(UTC).date()
# Downgrade pages_limit when canceling
if new_status == SubscriptionStatus.CANCELED:
@ -430,18 +485,25 @@ async def _handle_invoice_payment_succeeded(
# Reset tokens on subscription renewals and initial subscription creation
if billing_reason not in {"subscription_cycle", "subscription_create"}:
logger.info("invoice.payment_succeeded billing_reason=%s; not resetting tokens", billing_reason)
logger.info(
"invoice.payment_succeeded billing_reason=%s; not resetting tokens",
billing_reason,
)
return StripeWebhookResponse()
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping token reset", customer_id)
logger.warning(
"No user found for Stripe customer %s; skipping token reset", customer_id
)
return StripeWebhookResponse()
user.tokens_used_this_month = 0
user.token_reset_date = datetime.now(UTC).date()
logger.info("Reset tokens_used_this_month for user %s on subscription renewal", user.id)
logger.info(
"Reset tokens_used_this_month for user %s on subscription renewal", user.id
)
await db_session.commit()
return StripeWebhookResponse()
@ -456,7 +518,10 @@ async def _handle_invoice_payment_failed(
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping past_due update", customer_id)
logger.warning(
"No user found for Stripe customer %s; skipping past_due update",
customer_id,
)
return StripeWebhookResponse()
if user.subscription_status == SubscriptionStatus.ACTIVE:
@ -464,7 +529,11 @@ async def _handle_invoice_payment_failed(
logger.info("Set subscription to PAST_DUE for user %s", user.id)
await db_session.commit()
else:
logger.info("invoice.payment_failed for user %s already in status %s; no change", user.id, user.subscription_status)
logger.info(
"invoice.payment_failed for user %s already in status %s; no change",
user.id,
user.subscription_status,
)
return StripeWebhookResponse()
@ -477,26 +546,43 @@ async def _activate_subscription_from_checkout(
The full subscription lifecycle will also be handled by customer.subscription.created,
but we activate immediately here so the user sees Pro access right after checkout.
"""
customer_id = _normalize_optional_string(getattr(checkout_session, "customer", None))
subscription_id = _normalize_optional_string(getattr(checkout_session, "subscription", None))
customer_id = _normalize_optional_string(
getattr(checkout_session, "customer", None)
)
subscription_id = _normalize_optional_string(
getattr(checkout_session, "subscription", None)
)
metadata = _get_metadata(checkout_session)
plan_id_str = metadata.get("plan_id", "")
if not customer_id:
logger.error("Subscription checkout session missing customer ID: %s", getattr(checkout_session, "id", ""))
logger.error(
"Subscription checkout session missing customer ID: %s",
getattr(checkout_session, "id", ""),
)
return StripeWebhookResponse()
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
if user is None:
logger.warning("No user found for Stripe customer %s; skipping subscription activation", customer_id)
logger.warning(
"No user found for Stripe customer %s; skipping subscription activation",
customer_id,
)
return StripeWebhookResponse()
# Idempotency: already activated
if user.subscription_status == SubscriptionStatus.ACTIVE and user.stripe_subscription_id == subscription_id:
logger.info("Subscription already active for user %s; skipping activation", user.id)
if (
user.subscription_status == SubscriptionStatus.ACTIVE
and user.stripe_subscription_id == subscription_id
):
logger.info(
"Subscription already active for user %s; skipping activation", user.id
)
return StripeWebhookResponse()
plan_id = plan_id_str if plan_id_str in {"pro_monthly", "pro_yearly"} else "pro_monthly"
plan_id = (
plan_id_str if plan_id_str in {"pro_monthly", "pro_yearly"} else "pro_monthly"
)
limits = config.PLAN_LIMITS.get(plan_id, config.PLAN_LIMITS["pro_monthly"])
user.subscription_status = SubscriptionStatus.ACTIVE
@ -512,11 +598,20 @@ async def _activate_subscription_from_checkout(
try:
stripe_client = get_stripe_client()
sub_obj = stripe_client.v1.subscriptions.retrieve(subscription_id)
user.subscription_current_period_end = _period_end_from_subscription(sub_obj)
except Exception: # noqa: BLE001
logger.warning("Could not retrieve subscription %s for period_end", subscription_id)
user.subscription_current_period_end = _period_end_from_subscription(
sub_obj
)
except Exception:
logger.warning(
"Could not retrieve subscription %s for period_end", subscription_id
)
logger.info("Activated subscription for user %s: plan=%s subscription=%s", user.id, plan_id, subscription_id)
logger.info(
"Activated subscription for user %s: plan=%s subscription=%s",
user.id,
plan_id,
subscription_id,
)
await db_session.commit()
return StripeWebhookResponse()
@ -601,6 +696,47 @@ async def create_subscription_checkout(
db_session: AsyncSession = Depends(get_async_session),
) -> CreateSubscriptionCheckoutResponse:
"""Create a Stripe Checkout Session for a recurring subscription."""
# Admin-approval mode: when Stripe is not configured, queue a manual request
if not config.STRIPE_SECRET_KEY:
if user.subscription_status == SubscriptionStatus.ACTIVE:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="You already have an active subscription.",
)
existing = await db_session.execute(
select(SubscriptionRequest)
.where(SubscriptionRequest.user_id == user.id)
.where(SubscriptionRequest.status == SubscriptionRequestStatus.PENDING)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="You already have a pending subscription request.",
)
cooldown_cutoff = datetime.now(UTC) - timedelta(hours=24)
recently_rejected = await db_session.execute(
select(SubscriptionRequest)
.where(SubscriptionRequest.user_id == user.id)
.where(SubscriptionRequest.status == SubscriptionRequestStatus.REJECTED)
.where(SubscriptionRequest.created_at >= cooldown_cutoff)
)
if recently_rejected.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Your previous request was rejected. Please wait 24 hours before resubmitting.",
)
req = SubscriptionRequest(user_id=user.id, plan_id=body.plan_id.value)
db_session.add(req)
await db_session.commit()
logger.info(
"Admin-approval subscription request created for user %s (plan=%s)",
user.id,
body.plan_id.value,
)
return CreateSubscriptionCheckoutResponse(
checkout_url="", admin_approval_mode=True
)
stripe_client = get_stripe_client()
price_id = _get_price_id_for_plan(body.plan_id)
success_url, cancel_url = _get_subscription_urls()
@ -653,6 +789,7 @@ async def verify_checkout_session(
user: User = Depends(current_active_user),
) -> dict:
"""Verify a Stripe Checkout Session belongs to the user and is paid."""
_check_verify_session_rate_limit(str(user.id))
stripe_client = get_stripe_client()
try:
session = stripe_client.v1.checkout.sessions.retrieve(session_id)
@ -743,7 +880,9 @@ async def stripe_webhook(
return StripeWebhookResponse()
if session_mode == "subscription":
return await _activate_subscription_from_checkout(db_session, checkout_session)
return await _activate_subscription_from_checkout(
db_session, checkout_session
)
return await _fulfill_completed_purchase(db_session, checkout_session)

View file

@ -15,11 +15,12 @@ from app.db import (
from app.schemas import (
GlobalVisionLLMConfigRead,
VisionLLMConfigCreate,
VisionLLMConfigPublic,
VisionLLMConfigRead,
VisionLLMConfigUpdate,
)
from app.services.vision_model_list_service import get_vision_model_list
from app.users import current_active_user
from app.users import current_active_user, current_superuser
from app.utils.rbac import check_permission
router = APIRouter()
@ -118,18 +119,11 @@ async def get_global_vision_llm_configs(
async def create_vision_llm_config(
config_data: VisionLLMConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""Create a new vision LLM config. Superuser only."""
try:
await check_permission(
session,
user,
config_data.search_space_id,
Permission.VISION_CONFIGS_CREATE.value,
"You don't have permission to create vision LLM configs in this search space",
)
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id)
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=None)
session.add(db_config)
await session.commit()
await session.refresh(db_config)
@ -145,7 +139,7 @@ async def create_vision_llm_config(
) from e
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead])
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigPublic])
async def list_vision_llm_configs(
search_space_id: int,
skip: int = 0,
@ -164,7 +158,10 @@ async def list_vision_llm_configs(
result = await session.execute(
select(VisionLLMConfig)
.filter(VisionLLMConfig.search_space_id == search_space_id)
.filter(
(VisionLLMConfig.search_space_id == search_space_id)
| (VisionLLMConfig.search_space_id == None) # noqa: E711
)
.order_by(VisionLLMConfig.created_at.desc())
.offset(skip)
.limit(limit)
@ -217,24 +214,19 @@ async def update_vision_llm_config(
config_id: int,
update_data: VisionLLMConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""Update an existing vision LLM config. Superuser only."""
try:
result = await session.execute(
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
select(VisionLLMConfig).filter(
VisionLLMConfig.id == config_id, VisionLLMConfig.user_id.is_(None)
)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.VISION_CONFIGS_CREATE.value,
"You don't have permission to update vision LLM configs in this search space",
)
for key, value in update_data.model_dump(exclude_unset=True).items():
setattr(db_config, key, value)
@ -256,24 +248,19 @@ async def update_vision_llm_config(
async def delete_vision_llm_config(
config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
user: User = Depends(current_superuser),
):
"""Delete a vision LLM config. Superuser only."""
try:
result = await session.execute(
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
select(VisionLLMConfig).filter(
VisionLLMConfig.id == config_id, VisionLLMConfig.user_id.is_(None)
)
)
db_config = result.scalars().first()
if not db_config:
raise HTTPException(status_code=404, detail="Config not found")
await check_permission(
session,
user,
db_config.search_space_id,
Permission.VISION_CONFIGS_DELETE.value,
"You don't have permission to delete vision LLM configs in this search space",
)
await session.delete(db_config)
await session.commit()
return {

View file

@ -55,8 +55,9 @@ class ImageGenerationConfigBase(BaseModel):
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
"""Schema for creating a new ImageGenerationConfig."""
search_space_id: int = Field(
..., description="Search space ID to associate the config with"
search_space_id: int | None = Field(
None,
description="Search space ID. None = global admin config visible to all spaces",
)
@ -79,8 +80,8 @@ class ImageGenerationConfigRead(ImageGenerationConfigBase):
id: int
created_at: datetime
search_space_id: int
user_id: uuid.UUID
search_space_id: int | None = None
user_id: uuid.UUID | None = None
model_config = ConfigDict(from_attributes=True)
@ -98,8 +99,8 @@ class ImageGenerationConfigPublic(BaseModel):
api_version: str | None = None
litellm_params: dict[str, Any] | None = None
created_at: datetime
search_space_id: int
user_id: uuid.UUID
search_space_id: int | None = None
user_id: uuid.UUID | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -60,8 +60,9 @@ class NewLLMConfigBase(BaseModel):
class NewLLMConfigCreate(NewLLMConfigBase):
"""Schema for creating a new NewLLMConfig."""
search_space_id: int = Field(
..., description="Search space ID to associate the config with"
search_space_id: int | None = Field(
None,
description="Search space ID. None = global admin config visible to all spaces",
)
@ -90,8 +91,8 @@ class NewLLMConfigRead(NewLLMConfigBase):
id: int
created_at: datetime
search_space_id: int
user_id: uuid.UUID
search_space_id: int | None = None
user_id: uuid.UUID | None = None
model_config = ConfigDict(from_attributes=True)
@ -119,8 +120,8 @@ class NewLLMConfigPublic(BaseModel):
citations_enabled: bool
created_at: datetime
search_space_id: int
user_id: uuid.UUID
search_space_id: int | None = None
user_id: uuid.UUID | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -33,6 +33,7 @@ class CreateSubscriptionCheckoutResponse(BaseModel):
"""Response containing the Stripe-hosted subscription checkout URL."""
checkout_url: str
admin_approval_mode: bool = False
class CreateCheckoutSessionResponse(BaseModel):

View file

@ -6,6 +6,10 @@ from fastapi_users import schemas
class UserRead(schemas.BaseUser[uuid.UUID]):
pages_limit: int
pages_used: int
monthly_token_limit: int
tokens_used_this_month: int
plan_id: str
subscription_status: str
display_name: str | None = None
avatar_url: str | None = None

View file

@ -20,7 +20,10 @@ class VisionLLMConfigBase(BaseModel):
class VisionLLMConfigCreate(VisionLLMConfigBase):
search_space_id: int = Field(...)
search_space_id: int | None = Field(
None,
description="Search space ID. None = global admin config visible to all spaces",
)
class VisionLLMConfigUpdate(BaseModel):
@ -38,8 +41,8 @@ class VisionLLMConfigUpdate(BaseModel):
class VisionLLMConfigRead(VisionLLMConfigBase):
id: int
created_at: datetime
search_space_id: int
user_id: uuid.UUID
search_space_id: int | None = None
user_id: uuid.UUID | None = None
model_config = ConfigDict(from_attributes=True)
@ -55,8 +58,8 @@ class VisionLLMConfigPublic(BaseModel):
api_version: str | None = None
litellm_params: dict[str, Any] | None = None
created_at: datetime
search_space_id: int
user_id: uuid.UUID
search_space_id: int | None = None
user_id: uuid.UUID | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -51,16 +51,25 @@ class PageLimitService:
"""
from app.db import User
# Get user's current page usage
# Get user's current page usage and subscription status
result = await self.session.execute(
select(User.pages_used, User.pages_limit).where(User.id == user_id)
select(User.pages_used, User.pages_limit, User.subscription_status).where(
User.id == user_id
)
)
row = result.first()
if not row:
raise ValueError(f"User with ID {user_id} not found")
pages_used, pages_limit = row
pages_used, pages_limit, sub_status = row
# PAST_DUE: enforce free-tier page limit to prevent usage without payment
if str(sub_status).lower() == "past_due":
from app.config import config as app_config # avoid circular import
free_limit = app_config.PLAN_LIMITS.get("free", {}).get("pages_limit", 500)
pages_limit = min(pages_limit, free_limit)
# Check if adding estimated pages would exceed limit
if pages_used + estimated_pages > pages_limit:

View file

@ -97,6 +97,15 @@ class TokenQuotaService:
tokens_used = user.tokens_used_this_month or 0
token_limit = user.monthly_token_limit or 0
# PAST_DUE: enforce free-tier token limit to prevent usage without payment
if str(getattr(user, "subscription_status", "")).lower() == "past_due":
from app.config import config as app_config # avoid circular import
free_limit = app_config.PLAN_LIMITS.get("free", {}).get(
"monthly_token_limit", 50000
)
token_limit = min(token_limit, free_limit)
# Strict boundary: >= means at-limit is also exceeded
if tokens_used + estimated_tokens >= token_limit and token_limit > 0:
raise TokenQuotaExceededError(

View file

@ -300,3 +300,4 @@ fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True)
current_optional_user = fastapi_users.current_user(active=True, optional=True)
current_superuser = fastapi_users.current_user(active=True, superuser=True)