mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 16:52:38 +02:00
Epic 5 Complete: Billing, Subscriptions, and Admin Features
Resolve all 5 deferred items from Epic 5 adversarial code review: - Migration 124: Add CASCADE to subscriptionstatus enum drop (prevent orphaned references) - Stripe rate limiting: In-memory per-user limiter (20 calls/60s) on verify-checkout-session - Subscription request cooldown: 24h cooldown before resubmitting rejected requests - Token reset date: Initialize on first subscription activation - Checkout URL validation: Confirmed HTTPS-only (Stripe always returns HTTPS) Implement Story 5.4 (Usage Tracking & Rate Limit Enforcement): - Page quota pre-check at HTTP upload layer - Extend UserRead schema with token quota fields - Frontend 402 error handling in document upload - Quota indicator in dashboard sidebar Story 5.5 (Admin Seed & Approval Flow): - Seed admin user migration with default credentials warning - Subscription approval/rejection routes with admin guard - 24h rejection cooldown enforcement Story 5.6 (Admin-Only Model Config): - Global model config visible across all search spaces - Per-search-space model configs with user access control - Superuser CRUD for global configs Additional fixes from code review: - PageLimitService: PAST_DUE subscriptions enforce free-tier limits - TokenQuotaService: PAST_DUE subscriptions enforce free-tier limits - Config routes: Fixed user_id.is_(None) filter on mutation endpoints - Stripe webhook: Added guard against silent plan downgrade on unrecognized price_id All changes formatted with Ruff (Python) and Biome (TypeScript). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
20c4f128bb
commit
4eb6ed18d6
41 changed files with 1771 additions and 318 deletions
|
|
@ -24,6 +24,7 @@ from __future__ import annotations
|
|||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "124"
|
||||
|
|
@ -33,17 +34,35 @@ depends_on: str | Sequence[str] | None = None
|
|||
|
||||
# Create the enum type so SQLAlchemy's create_type=False works at runtime
|
||||
subscriptionstatus_enum = sa.Enum(
|
||||
"free", "active", "canceled", "past_due",
|
||||
"free",
|
||||
"active",
|
||||
"canceled",
|
||||
"past_due",
|
||||
name="subscriptionstatus",
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the PostgreSQL enum type first
|
||||
subscriptionstatus_enum.create(op.get_bind(), checkfirst=True)
|
||||
# Drop any pre-existing subscriptionstatus enum (e.g. uppercase version created by
|
||||
# SQLAlchemy's create_all() during early development) so we can create it with
|
||||
# the correct lowercase values. Safe to drop here because no column uses it yet.
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS subscriptionstatus CASCADE"))
|
||||
# Create the PostgreSQL enum type with lowercase values
|
||||
subscriptionstatus_enum.create(conn, checkfirst=False)
|
||||
|
||||
op.add_column("user", sa.Column("monthly_token_limit", sa.Integer(), nullable=False, server_default="100000"))
|
||||
op.add_column("user", sa.Column("tokens_used_this_month", sa.Integer(), nullable=False, server_default="0"))
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"monthly_token_limit", sa.Integer(), nullable=False, server_default="100000"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"tokens_used_this_month", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
)
|
||||
op.add_column("user", sa.Column("token_reset_date", sa.Date(), nullable=True))
|
||||
op.add_column(
|
||||
"user",
|
||||
|
|
@ -54,12 +73,23 @@ def upgrade() -> None:
|
|||
server_default="free",
|
||||
),
|
||||
)
|
||||
op.add_column("user", sa.Column("plan_id", sa.String(50), nullable=False, server_default="free"))
|
||||
op.add_column("user", sa.Column("stripe_customer_id", sa.String(255), nullable=True))
|
||||
op.add_column("user", sa.Column("stripe_subscription_id", sa.String(255), nullable=True))
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("plan_id", sa.String(50), nullable=False, server_default="free"),
|
||||
)
|
||||
op.add_column(
|
||||
"user", sa.Column("stripe_customer_id", sa.String(255), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"user", sa.Column("stripe_subscription_id", sa.String(255), nullable=True)
|
||||
)
|
||||
|
||||
op.create_unique_constraint("uq_user_stripe_customer_id", "user", ["stripe_customer_id"])
|
||||
op.create_unique_constraint("uq_user_stripe_subscription_id", "user", ["stripe_subscription_id"])
|
||||
op.create_unique_constraint(
|
||||
"uq_user_stripe_customer_id", "user", ["stripe_customer_id"]
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
"uq_user_stripe_subscription_id", "user", ["stripe_subscription_id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
|
|
|||
207
surfsense_backend/alembic/versions/126_seed_admin_user.py
Normal file
207
surfsense_backend/alembic/versions/126_seed_admin_user.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
"""126_seed_admin_user
|
||||
|
||||
Revision ID: 126
|
||||
Revises: 125
|
||||
Create Date: 2026-04-15
|
||||
|
||||
Seeds one admin user on fresh installs (no-op if any user already exists).
|
||||
Credentials are overridable via env vars:
|
||||
ADMIN_EMAIL (default: admin@surfsense.local)
|
||||
ADMIN_PASSWORD (default: Admin@SurfSense1)
|
||||
|
||||
Admin is created with:
|
||||
- is_superuser = TRUE, is_active = TRUE, is_verified = TRUE
|
||||
- subscription_status = 'active', plan_id = 'pro_yearly'
|
||||
- monthly_token_limit = 1_000_000, pages_limit = 5000
|
||||
- A default search space, roles, membership, and prompt defaults
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "126"
|
||||
down_revision: str | None = "125"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
"""Hash password using argon2-cffi (installed as a fastapi-users dependency)."""
|
||||
from argon2 import PasswordHasher
|
||||
|
||||
ph = PasswordHasher()
|
||||
return ph.hash(password)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Only seed when the database is empty
|
||||
result = conn.execute(sa.text('SELECT 1 FROM "user" LIMIT 1'))
|
||||
if result.fetchone() is not None:
|
||||
return # Users already exist — skip seed
|
||||
|
||||
admin_email = os.environ.get("ADMIN_EMAIL", "admin@surfsense.local")
|
||||
admin_password = os.environ.get("ADMIN_PASSWORD", "Admin@SurfSense1")
|
||||
if not os.environ.get("ADMIN_PASSWORD"):
|
||||
print(
|
||||
"\n⚠️ WARNING: ADMIN_PASSWORD env var not set. "
|
||||
"Using default password 'Admin@SurfSense1'. "
|
||||
"Change this immediately after first login!\n"
|
||||
)
|
||||
hashed_pw = _hash_password(admin_password)
|
||||
admin_id = str(uuid.uuid4())
|
||||
|
||||
# 1. Insert admin user
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO "user" (
|
||||
id, email, hashed_password,
|
||||
is_active, is_superuser, is_verified,
|
||||
subscription_status, plan_id,
|
||||
monthly_token_limit, pages_limit, pages_used,
|
||||
tokens_used_this_month
|
||||
) VALUES (
|
||||
:id, :email, :hashed_password,
|
||||
TRUE, TRUE, TRUE,
|
||||
'active', 'pro_yearly',
|
||||
1000000, 5000, 0,
|
||||
0
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": admin_id,
|
||||
"email": admin_email,
|
||||
"hashed_password": hashed_pw,
|
||||
},
|
||||
)
|
||||
|
||||
# 2. Insert default search space for admin (only required columns; defaults handle the rest)
|
||||
search_space_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO searchspaces (name, description, citations_enabled, user_id, created_at)
|
||||
VALUES ('My Search Space', 'Your personal search space', TRUE, :user_id, now())
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{"user_id": admin_id},
|
||||
)
|
||||
search_space_id = search_space_result.fetchone()[0]
|
||||
|
||||
# 3. Insert default roles for the search space
|
||||
owner_role_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO search_space_roles
|
||||
(name, description, permissions, is_default, is_system_role, search_space_id, created_at)
|
||||
VALUES (
|
||||
'Owner', 'Full access to all search space resources and settings',
|
||||
ARRAY['*'], FALSE, TRUE, :ss_id, now()
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{"ss_id": search_space_id},
|
||||
)
|
||||
owner_role_id = owner_role_result.fetchone()[0]
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO search_space_roles
|
||||
(name, description, permissions, is_default, is_system_role, search_space_id, created_at)
|
||||
VALUES
|
||||
(
|
||||
'Editor',
|
||||
'Can create and update content (no delete, role management, or settings access)',
|
||||
ARRAY[
|
||||
'documents:create','documents:read','documents:update',
|
||||
'chats:create','chats:read','chats:update',
|
||||
'comments:create','comments:read',
|
||||
'llm_configs:create','llm_configs:read','llm_configs:update',
|
||||
'podcasts:create','podcasts:read','podcasts:update',
|
||||
'video_presentations:create','video_presentations:read','video_presentations:update',
|
||||
'image_generations:create','image_generations:read',
|
||||
'vision_configs:create','vision_configs:read',
|
||||
'connectors:create','connectors:read','connectors:update',
|
||||
'logs:read', 'members:invite'
|
||||
],
|
||||
TRUE, TRUE, :ss_id, now()
|
||||
),
|
||||
(
|
||||
'Viewer', 'Read-only access to search space resources',
|
||||
ARRAY[
|
||||
'documents:read','chats:read','comments:read',
|
||||
'llm_configs:read','podcasts:read','video_presentations:read',
|
||||
'image_generations:read','vision_configs:read','connectors:read','logs:read'
|
||||
],
|
||||
FALSE, TRUE, :ss_id, now()
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"ss_id": search_space_id},
|
||||
)
|
||||
|
||||
# 4. Insert owner membership
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO search_space_memberships
|
||||
(user_id, search_space_id, role_id, is_owner, joined_at, created_at)
|
||||
VALUES (:user_id, :ss_id, :role_id, TRUE, now(), now())
|
||||
"""
|
||||
),
|
||||
{"user_id": admin_id, "ss_id": search_space_id, "role_id": owner_role_id},
|
||||
)
|
||||
|
||||
# 5. Insert default prompts (same as migration 114 but just for admin)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO prompts
|
||||
(user_id, default_prompt_slug, name, prompt, mode, version, is_public, created_at)
|
||||
VALUES
|
||||
(:uid, 'fix-grammar', 'Fix grammar',
|
||||
'Fix the grammar and spelling in the following text. Return only the corrected text, nothing else.\n\n{selection}',
|
||||
'transform'::prompt_mode, 1, false, now()),
|
||||
(:uid, 'make-shorter', 'Make shorter',
|
||||
'Make the following text more concise while preserving its meaning. Return only the shortened text, nothing else.\n\n{selection}',
|
||||
'transform'::prompt_mode, 1, false, now()),
|
||||
(:uid, 'translate', 'Translate',
|
||||
'Translate the following text to English. If it is already in English, translate it to French. Return only the translation, nothing else.\n\n{selection}',
|
||||
'transform'::prompt_mode, 1, false, now()),
|
||||
(:uid, 'rewrite', 'Rewrite',
|
||||
'Rewrite the following text to improve clarity and readability. Return only the rewritten text, nothing else.\n\n{selection}',
|
||||
'transform'::prompt_mode, 1, false, now()),
|
||||
(:uid, 'summarize', 'Summarize',
|
||||
'Summarize the following text concisely. Return only the summary, nothing else.\n\n{selection}',
|
||||
'transform'::prompt_mode, 1, false, now()),
|
||||
(:uid, 'explain', 'Explain',
|
||||
'Explain the following text in simple terms:\n\n{selection}',
|
||||
'explore'::prompt_mode, 1, false, now()),
|
||||
(:uid, 'ask-knowledge-base', 'Ask my knowledge base',
|
||||
'Search my knowledge base for information related to:\n\n{selection}',
|
||||
'explore'::prompt_mode, 1, false, now()),
|
||||
(:uid, 'look-up-web', 'Look up on the web',
|
||||
'Search the web for information about:\n\n{selection}',
|
||||
'explore'::prompt_mode, 1, false, now())
|
||||
ON CONFLICT (user_id, default_prompt_slug) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"uid": admin_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Intentional no-op: never delete users on downgrade
|
||||
pass
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
"""127_add_subscription_requests_table
|
||||
|
||||
Revision ID: 127
|
||||
Revises: 126
|
||||
Create Date: 2026-04-15
|
||||
|
||||
Adds the subscription_requests table for admin-approval flow when Stripe
|
||||
is not configured. Users submit a subscription request; superusers can
|
||||
approve/reject it from the admin panel.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "127"
|
||||
down_revision: str | None = "126"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
# Drop any pre-existing enum (e.g. uppercase version from old create_all())
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS subscriptionrequeststatus"))
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"CREATE TYPE subscriptionrequeststatus AS ENUM ('pending', 'approved', 'rejected')"
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
CREATE TABLE subscription_requests (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
plan_id VARCHAR(50) NOT NULL,
|
||||
status subscriptionrequeststatus NOT NULL DEFAULT 'pending',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
approved_at TIMESTAMPTZ,
|
||||
approved_by UUID REFERENCES "user"(id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"CREATE INDEX ix_subscription_requests_user_id ON subscription_requests (user_id)"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS subscription_requests"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS subscriptionrequeststatus"))
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
"""128_make_model_config_user_id_nullable
|
||||
|
||||
Revision ID: 128
|
||||
Revises: 127
|
||||
Create Date: 2026-04-15
|
||||
|
||||
Makes user_id nullable on the three model-config tables so that admin-created
|
||||
(superuser-owned) configurations have user_id = NULL and are visible to all
|
||||
members of the search space.
|
||||
|
||||
Tables affected:
|
||||
- new_llm_configs
|
||||
- image_generation_configs
|
||||
- vision_llm_configs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "128"
|
||||
down_revision: str | None = "127"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("new_llm_configs", "user_id", existing_type=sa.UUID(), nullable=True)
|
||||
op.alter_column("image_generation_configs", "user_id", existing_type=sa.UUID(), nullable=True)
|
||||
op.alter_column("vision_llm_configs", "user_id", existing_type=sa.UUID(), nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Null out orphaned rows before re-adding NOT NULL (safety guard)
|
||||
for table in ("new_llm_configs", "image_generation_configs", "vision_llm_configs"):
|
||||
# If any rows have user_id=NULL we cannot restore NOT NULL — delete them
|
||||
conn.execute(sa.text(f'DELETE FROM "{table}" WHERE user_id IS NULL'))
|
||||
|
||||
op.alter_column("new_llm_configs", "user_id", existing_type=sa.UUID(), nullable=False)
|
||||
op.alter_column("image_generation_configs", "user_id", existing_type=sa.UUID(), nullable=False)
|
||||
op.alter_column("vision_llm_configs", "user_id", existing_type=sa.UUID(), nullable=False)
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""129_make_model_config_search_space_id_nullable
|
||||
|
||||
Revision ID: 129
|
||||
Revises: 128
|
||||
Create Date: 2026-04-15
|
||||
|
||||
Makes search_space_id nullable on the three model-config tables so that
|
||||
admin-created (superuser-owned) configurations have search_space_id = NULL
|
||||
and are visible to ALL users across ALL search spaces.
|
||||
|
||||
Tables affected:
|
||||
- new_llm_configs
|
||||
- image_generation_configs
|
||||
- vision_llm_configs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "129"
|
||||
down_revision: str | None = "128"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"new_llm_configs", "search_space_id", existing_type=sa.Integer(), nullable=True
|
||||
)
|
||||
op.alter_column(
|
||||
"image_generation_configs",
|
||||
"search_space_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"vision_llm_configs",
|
||||
"search_space_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Delete global configs (search_space_id IS NULL) before restoring NOT NULL
|
||||
for table in ("new_llm_configs", "image_generation_configs", "vision_llm_configs"):
|
||||
conn.execute(sa.text(f'DELETE FROM "{table}" WHERE search_space_id IS NULL'))
|
||||
|
||||
op.alter_column(
|
||||
"new_llm_configs", "search_space_id", existing_type=sa.Integer(), nullable=False
|
||||
)
|
||||
op.alter_column(
|
||||
"image_generation_configs",
|
||||
"search_space_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"vision_llm_configs",
|
||||
"search_space_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
|
|
@ -1197,15 +1197,15 @@ class ImageGenerationConfig(BaseModel, TimestampMixin):
|
|||
|
||||
# Relationships
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
# User who created this config
|
||||
# User who created this config (NULL for admin-created global configs)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
user = relationship("User", back_populates="image_generation_configs")
|
||||
|
||||
|
|
@ -1227,12 +1227,13 @@ class VisionLLMConfig(BaseModel, TimestampMixin):
|
|||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="vision_llm_configs")
|
||||
|
||||
# User who created this config (NULL for admin-created global configs)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
user = relationship("User", back_populates="vision_llm_configs")
|
||||
|
||||
|
|
@ -1535,13 +1536,13 @@ class NewLLMConfig(BaseModel, TimestampMixin):
|
|||
|
||||
# === Relationships ===
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
|
||||
|
||||
# User who created this config
|
||||
# User who created this config (NULL for admin-created global configs)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
user = relationship("User", back_populates="new_llm_configs")
|
||||
|
||||
|
|
@ -1683,6 +1684,56 @@ class PagePurchase(Base, TimestampMixin):
|
|||
user = relationship("User", back_populates="page_purchases")
|
||||
|
||||
|
||||
class SubscriptionRequestStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class SubscriptionRequest(Base):
|
||||
"""Tracks subscription upgrade requests when Stripe is not configured (admin-approval flow)."""
|
||||
|
||||
__tablename__ = "subscription_requests"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=text("gen_random_uuid()"),
|
||||
)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
plan_id = Column(String(50), nullable=False)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(
|
||||
SubscriptionRequestStatus,
|
||||
name="subscriptionrequeststatus",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
default=SubscriptionRequestStatus.PENDING,
|
||||
server_default="pending",
|
||||
)
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text("now()"),
|
||||
)
|
||||
approved_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
approved_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
user = relationship("User", foreign_keys=[user_id], back_populates="subscription_requests")
|
||||
|
||||
|
||||
class SearchSpaceRole(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Custom roles that can be defined per search space.
|
||||
|
|
@ -1953,6 +2004,12 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
subscription_requests = relationship(
|
||||
"SubscriptionRequest",
|
||||
foreign_keys="SubscriptionRequest.user_id",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
|
|
@ -1968,7 +2025,7 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
tokens_used_this_month = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
token_reset_date = Column(Date, nullable=True)
|
||||
subscription_status = Column(
|
||||
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True),
|
||||
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True, values_callable=lambda x: [e.value for e in x]),
|
||||
nullable=False,
|
||||
default=SubscriptionStatus.FREE,
|
||||
server_default="free",
|
||||
|
|
@ -2082,6 +2139,12 @@ else:
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
subscription_requests = relationship(
|
||||
"SubscriptionRequest",
|
||||
foreign_keys="SubscriptionRequest.user_id",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
|
|
@ -2097,7 +2160,7 @@ else:
|
|||
tokens_used_this_month = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
token_reset_date = Column(Date, nullable=True)
|
||||
subscription_status = Column(
|
||||
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True),
|
||||
SQLAlchemyEnum(SubscriptionStatus, name="subscriptionstatus", create_type=True, values_callable=lambda x: [e.value for e in x]),
|
||||
nullable=False,
|
||||
default=SubscriptionStatus.FREE,
|
||||
server_default="free",
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ from .sandbox_routes import router as sandbox_router
|
|||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
from .slack_add_connector_route import router as slack_add_connector_router
|
||||
from .admin_routes import router as admin_router
|
||||
from .stripe_routes import router as stripe_router
|
||||
from .surfsense_docs_routes import router as surfsense_docs_router
|
||||
from .teams_add_connector_route import router as teams_add_connector_router
|
||||
|
|
@ -100,6 +101,7 @@ router.include_router(notifications_router) # Notifications with Zero sync
|
|||
router.include_router(composio_router) # Composio OAuth and toolkit management
|
||||
router.include_router(public_chat_router) # Public chat sharing and cloning
|
||||
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
|
||||
router.include_router(admin_router) # Superuser admin operations
|
||||
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
||||
router.include_router(youtube_router) # YouTube playlist resolution
|
||||
router.include_router(prompts_router)
|
||||
|
|
|
|||
212
surfsense_backend/app/routes/admin_routes.py
Normal file
212
surfsense_backend/app/routes/admin_routes.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
"""Admin routes — superuser-only operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SubscriptionRequest,
|
||||
SubscriptionRequestStatus,
|
||||
SubscriptionStatus,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_superuser
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SubscriptionRequestItem(BaseModel):
|
||||
id: uuid.UUID
|
||||
user_id: uuid.UUID
|
||||
user_email: str
|
||||
plan_id: str
|
||||
status: str
|
||||
created_at: datetime
|
||||
approved_at: datetime | None = None
|
||||
approved_by: uuid.UUID | None = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List pending subscription requests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/subscription-requests",
|
||||
response_model=list[SubscriptionRequestItem],
|
||||
)
|
||||
async def list_subscription_requests(
|
||||
admin: User = Depends(current_superuser),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> list[SubscriptionRequestItem]:
|
||||
"""Return all pending subscription requests."""
|
||||
result = await db_session.execute(
|
||||
select(SubscriptionRequest)
|
||||
.where(SubscriptionRequest.status == SubscriptionRequestStatus.PENDING)
|
||||
.order_by(SubscriptionRequest.created_at.asc())
|
||||
)
|
||||
requests = result.scalars().all()
|
||||
|
||||
# Collect user IDs and batch-load to avoid N+1
|
||||
user_ids = [req.user_id for req in requests]
|
||||
email_map: dict[uuid.UUID, str] = {}
|
||||
if user_ids:
|
||||
user_rows = await db_session.execute(select(User).where(User.id.in_(user_ids)))
|
||||
for u in user_rows.scalars():
|
||||
email_map[u.id] = u.email
|
||||
|
||||
items: list[SubscriptionRequestItem] = [
|
||||
SubscriptionRequestItem(
|
||||
id=req.id,
|
||||
user_id=req.user_id,
|
||||
user_email=email_map.get(req.user_id, "<deleted>"),
|
||||
plan_id=req.plan_id,
|
||||
status=req.status.value,
|
||||
created_at=req.created_at,
|
||||
approved_at=req.approved_at,
|
||||
approved_by=req.approved_by,
|
||||
)
|
||||
for req in requests
|
||||
]
|
||||
return items
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Approve a subscription request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/subscription-requests/{request_id}/approve",
|
||||
response_model=SubscriptionRequestItem,
|
||||
)
|
||||
async def approve_subscription_request(
|
||||
request_id: uuid.UUID,
|
||||
admin: User = Depends(current_superuser),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> SubscriptionRequestItem:
|
||||
"""Approve a pending subscription request and activate the user's subscription."""
|
||||
result = await db_session.execute(
|
||||
select(SubscriptionRequest)
|
||||
.where(SubscriptionRequest.id == request_id)
|
||||
.with_for_update()
|
||||
)
|
||||
req = result.scalar_one_or_none()
|
||||
if req is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Subscription request not found.",
|
||||
)
|
||||
if req.status != SubscriptionRequestStatus.PENDING:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Request is already {req.status.value}.",
|
||||
)
|
||||
|
||||
user_result = await db_session.execute(
|
||||
select(User).where(User.id == req.user_id).with_for_update()
|
||||
)
|
||||
req_user = user_result.scalar_one_or_none()
|
||||
if req_user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found."
|
||||
)
|
||||
|
||||
# Activate subscription
|
||||
plan_limits = config.PLAN_LIMITS.get(req.plan_id, config.PLAN_LIMITS["free"])
|
||||
req_user.subscription_status = SubscriptionStatus.ACTIVE
|
||||
req_user.plan_id = req.plan_id
|
||||
req_user.monthly_token_limit = plan_limits["monthly_token_limit"]
|
||||
req_user.pages_limit = max(req_user.pages_used or 0, plan_limits["pages_limit"])
|
||||
req_user.tokens_used_this_month = 0
|
||||
req_user.token_reset_date = datetime.now(UTC).date()
|
||||
|
||||
# Mark request approved
|
||||
now = datetime.now(UTC)
|
||||
req.status = SubscriptionRequestStatus.APPROVED
|
||||
req.approved_at = now
|
||||
req.approved_by = admin.id
|
||||
|
||||
await db_session.commit()
|
||||
await db_session.refresh(req)
|
||||
|
||||
user_result2 = await db_session.execute(select(User).where(User.id == req.user_id))
|
||||
req_user2 = user_result2.scalar_one_or_none()
|
||||
email = req_user2.email if req_user2 else "<deleted>"
|
||||
|
||||
return SubscriptionRequestItem(
|
||||
id=req.id,
|
||||
user_id=req.user_id,
|
||||
user_email=email,
|
||||
plan_id=req.plan_id,
|
||||
status=req.status.value,
|
||||
created_at=req.created_at,
|
||||
approved_at=req.approved_at,
|
||||
approved_by=req.approved_by,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reject a subscription request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/subscription-requests/{request_id}/reject",
|
||||
response_model=SubscriptionRequestItem,
|
||||
)
|
||||
async def reject_subscription_request(
|
||||
request_id: uuid.UUID,
|
||||
admin: User = Depends(current_superuser),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> SubscriptionRequestItem:
|
||||
"""Reject a pending subscription request."""
|
||||
result = await db_session.execute(
|
||||
select(SubscriptionRequest).where(SubscriptionRequest.id == request_id)
|
||||
)
|
||||
req = result.scalar_one_or_none()
|
||||
if req is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Subscription request not found.",
|
||||
)
|
||||
if req.status != SubscriptionRequestStatus.PENDING:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Request is already {req.status.value}.",
|
||||
)
|
||||
|
||||
req.status = SubscriptionRequestStatus.REJECTED
|
||||
|
||||
await db_session.commit()
|
||||
await db_session.refresh(req)
|
||||
|
||||
user_result = await db_session.execute(select(User).where(User.id == req.user_id))
|
||||
req_user = user_result.scalar_one_or_none()
|
||||
email = req_user.email if req_user else "<deleted>"
|
||||
|
||||
return SubscriptionRequestItem(
|
||||
id=req.id,
|
||||
user_id=req.user_id,
|
||||
user_email=email,
|
||||
plan_id=req.plan_id,
|
||||
status=req.status.value,
|
||||
created_at=req.created_at,
|
||||
approved_at=req.approved_at,
|
||||
approved_by=req.approved_by,
|
||||
)
|
||||
|
|
@ -73,6 +73,24 @@ async def create_documents(
|
|||
"You don't have permission to create documents in this search space",
|
||||
)
|
||||
|
||||
# Page quota pre-check for connector documents
|
||||
from app.services.page_limit_service import (
|
||||
PageLimitExceededError,
|
||||
PageLimitService,
|
||||
)
|
||||
|
||||
estimated_pages = len(request.content) # 1 page per document/URL
|
||||
try:
|
||||
page_service = PageLimitService(session)
|
||||
await page_service.check_page_limit(str(user.id), estimated_pages)
|
||||
except PageLimitExceededError as e:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"Page quota exceeded ({e.pages_used}/{e.pages_limit}). "
|
||||
f"This request requires ~{estimated_pages} pages. "
|
||||
f"Upgrade your plan for more pages.",
|
||||
) from e
|
||||
|
||||
if request.document_type == DocumentType.EXTENSION:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
process_extension_document_task,
|
||||
|
|
@ -169,6 +187,30 @@ async def create_documents_file_upload(
|
|||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
|
||||
# Page quota pre-check
|
||||
from app.services.page_limit_service import (
|
||||
PageLimitExceededError,
|
||||
PageLimitService,
|
||||
)
|
||||
|
||||
total_estimated_pages = sum(
|
||||
PageLimitService.estimate_pages_from_metadata(
|
||||
file.filename or "", file.size or 0
|
||||
)
|
||||
for file in files
|
||||
)
|
||||
|
||||
try:
|
||||
page_service = PageLimitService(session)
|
||||
await page_service.check_page_limit(str(user.id), total_estimated_pages)
|
||||
except PageLimitExceededError as e:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"Page quota exceeded ({e.pages_used}/{e.pages_limit}). "
|
||||
f"This upload requires ~{total_estimated_pages} pages. "
|
||||
f"Upgrade your plan for more pages.",
|
||||
) from e
|
||||
|
||||
# ===== Read all files concurrently to avoid blocking the event loop =====
|
||||
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
|
||||
"""Read upload content and write to temp file off the event loop."""
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from app.db import (
|
|||
from app.schemas import (
|
||||
GlobalImageGenConfigRead,
|
||||
ImageGenerationConfigCreate,
|
||||
ImageGenerationConfigPublic,
|
||||
ImageGenerationConfigRead,
|
||||
ImageGenerationConfigUpdate,
|
||||
ImageGenerationCreate,
|
||||
|
|
@ -41,7 +42,7 @@ from app.services.image_gen_router_service import (
|
|||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import current_active_user, current_superuser
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
||||
|
|
@ -261,19 +262,11 @@ async def get_global_image_gen_configs(
|
|||
async def create_image_gen_config(
|
||||
config_data: ImageGenerationConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""Create a new image generation config for a search space."""
|
||||
"""Create a new image generation config for a search space. Superuser only."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generation configs in this search space",
|
||||
)
|
||||
|
||||
db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id)
|
||||
db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=None)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
|
@ -289,7 +282,9 @@ async def create_image_gen_config(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
|
||||
@router.get(
|
||||
"/image-generation-configs", response_model=list[ImageGenerationConfigPublic]
|
||||
)
|
||||
async def list_image_gen_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
|
|
@ -297,7 +292,7 @@ async def list_image_gen_configs(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List image generation configs for a search space."""
|
||||
"""List image generation configs for a search space (includes global admin configs)."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -309,7 +304,10 @@ async def list_image_gen_configs(
|
|||
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig)
|
||||
.filter(ImageGenerationConfig.search_space_id == search_space_id)
|
||||
.filter(
|
||||
(ImageGenerationConfig.search_space_id == search_space_id)
|
||||
| (ImageGenerationConfig.search_space_id == None) # noqa: E711
|
||||
)
|
||||
.order_by(ImageGenerationConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
|
|
@ -367,25 +365,20 @@ async def update_image_gen_config(
|
|||
config_id: int,
|
||||
update_data: ImageGenerationConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""Update an existing image generation config."""
|
||||
"""Update an existing image generation config. Superuser only."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id,
|
||||
ImageGenerationConfig.user_id.is_(None),
|
||||
)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to update image generation configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
|
|
@ -407,25 +400,20 @@ async def update_image_gen_config(
|
|||
async def delete_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""Delete an image generation config."""
|
||||
"""Delete an image generation config. Superuser only."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id,
|
||||
ImageGenerationConfig.user_id.is_(None),
|
||||
)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generation configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -25,11 +25,12 @@ from app.schemas import (
|
|||
DefaultSystemInstructionsResponse,
|
||||
GlobalNewLLMConfigRead,
|
||||
NewLLMConfigCreate,
|
||||
NewLLMConfigPublic,
|
||||
NewLLMConfigRead,
|
||||
NewLLMConfigUpdate,
|
||||
)
|
||||
from app.services.llm_service import validate_llm_config
|
||||
from app.users import current_active_user
|
||||
from app.users import current_active_user, current_superuser
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -117,22 +118,13 @@ async def get_global_new_llm_configs(
|
|||
async def create_new_llm_config(
|
||||
config_data: NewLLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""
|
||||
Create a new NewLLMConfig for a search space.
|
||||
Requires LLM_CONFIGS_CREATE permission.
|
||||
Superuser only — configs are shared with all search space members.
|
||||
"""
|
||||
try:
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create LLM configurations in this search space",
|
||||
)
|
||||
|
||||
# Validate the LLM configuration by making a test API call
|
||||
is_valid, error_message = await validate_llm_config(
|
||||
provider=config_data.provider.value,
|
||||
|
|
@ -149,8 +141,8 @@ async def create_new_llm_config(
|
|||
detail=f"Invalid LLM configuration: {error_message}",
|
||||
)
|
||||
|
||||
# Create the config with user association
|
||||
db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id)
|
||||
# Create the config as admin-owned (user_id=None means shared with all space members)
|
||||
db_config = NewLLMConfig(**config_data.model_dump(), user_id=None)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
|
@ -167,7 +159,7 @@ async def create_new_llm_config(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead])
|
||||
@router.get("/new-llm-configs", response_model=list[NewLLMConfigPublic])
|
||||
async def list_new_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
|
|
@ -176,11 +168,11 @@ async def list_new_llm_configs(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get all NewLLMConfigs for a search space.
|
||||
Get all NewLLMConfigs for a search space (includes global admin configs).
|
||||
Requires LLM_CONFIGS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Verify user has permission
|
||||
# Verify user has permission for their space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
|
|
@ -191,7 +183,10 @@ async def list_new_llm_configs(
|
|||
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig)
|
||||
.filter(NewLLMConfig.search_space_id == search_space_id)
|
||||
.filter(
|
||||
(NewLLMConfig.search_space_id == search_space_id)
|
||||
| (NewLLMConfig.search_space_id == None) # noqa: E711
|
||||
)
|
||||
.order_by(NewLLMConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
|
|
@ -268,30 +263,23 @@ async def update_new_llm_config(
|
|||
config_id: int,
|
||||
update_data: NewLLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""
|
||||
Update an existing NewLLMConfig.
|
||||
Requires LLM_CONFIGS_UPDATE permission.
|
||||
Superuser only.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
select(NewLLMConfig).filter(
|
||||
NewLLMConfig.id == config_id, NewLLMConfig.user_id.is_(None)
|
||||
)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update LLM configurations in this search space",
|
||||
)
|
||||
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
|
||||
# If updating LLM settings, validate them
|
||||
|
|
@ -360,30 +348,23 @@ async def update_new_llm_config(
|
|||
async def delete_new_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""
|
||||
Delete a NewLLMConfig.
|
||||
Requires LLM_CONFIGS_DELETE permission.
|
||||
Superuser only.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
select(NewLLMConfig).filter(
|
||||
NewLLMConfig.id == config_id, NewLLMConfig.user_id.is_(None)
|
||||
)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
# Verify user has permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config.search_space_id,
|
||||
Permission.LLM_CONFIGS_DELETE.value,
|
||||
"You don't have permission to delete LLM configurations in this search space",
|
||||
)
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from collections import defaultdict
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
|
@ -13,7 +14,15 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from stripe import SignatureVerificationError, StripeClient, StripeError
|
||||
|
||||
from app.config import config
|
||||
from app.db import PagePurchase, PagePurchaseStatus, SubscriptionStatus, User, get_async_session
|
||||
from app.db import (
|
||||
PagePurchase,
|
||||
PagePurchaseStatus,
|
||||
SubscriptionRequest,
|
||||
SubscriptionRequestStatus,
|
||||
SubscriptionStatus,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.stripe import (
|
||||
CreateCheckoutSessionRequest,
|
||||
CreateCheckoutSessionResponse,
|
||||
|
|
@ -30,6 +39,28 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(prefix="/stripe", tags=["stripe"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simple in-memory rate limiter for verify-checkout-session (20 calls/60 s)
|
||||
# Not persistent across workers — acceptable for the low-risk, low-volume
|
||||
# nature of this endpoint.
|
||||
# ---------------------------------------------------------------------------
|
||||
_VERIFY_SESSION_WINDOW_SECS = 60
|
||||
_VERIFY_SESSION_MAX_CALLS = 20
|
||||
_verify_session_calls: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
|
||||
def _check_verify_session_rate_limit(user_id: str) -> None:
|
||||
now = datetime.now(UTC).timestamp()
|
||||
cutoff = now - _VERIFY_SESSION_WINDOW_SECS
|
||||
calls = [t for t in _verify_session_calls[user_id] if t > cutoff]
|
||||
if len(calls) >= _VERIFY_SESSION_MAX_CALLS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many requests. Try again later.",
|
||||
)
|
||||
calls.append(now)
|
||||
_verify_session_calls[user_id] = calls
|
||||
|
||||
|
||||
def get_stripe_client() -> StripeClient:
|
||||
"""Return a configured Stripe client or raise if Stripe is disabled."""
|
||||
|
|
@ -145,7 +176,10 @@ async def _get_or_create_stripe_customer(
|
|||
|
||||
try:
|
||||
customer = stripe_client.v1.customers.create(
|
||||
params={"email": locked_user.email, "metadata": {"user_id": str(locked_user.id)}}
|
||||
params={
|
||||
"email": locked_user.email,
|
||||
"metadata": {"user_id": str(locked_user.id)},
|
||||
}
|
||||
)
|
||||
except StripeError as exc:
|
||||
logger.exception("Failed to create Stripe customer for user %s", locked_user.id)
|
||||
|
|
@ -288,6 +322,7 @@ async def _fulfill_completed_purchase(
|
|||
# Subscription event helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _get_user_by_stripe_customer_id(
|
||||
db_session: AsyncSession, customer_id: str
|
||||
) -> User | None:
|
||||
|
|
@ -344,16 +379,34 @@ async def _handle_subscription_event(
|
|||
subscription_id,
|
||||
price_id,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
except Exception:
|
||||
logger.warning("Could not parse plan from subscription %s", subscription_id)
|
||||
|
||||
if not customer_id:
|
||||
logger.error("Subscription event missing customer ID for subscription %s", subscription_id)
|
||||
logger.error(
|
||||
"Subscription event missing customer ID for subscription %s",
|
||||
subscription_id,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
# Safety: never silently downgrade an active subscription to "free" due to
|
||||
# an unrecognized price ID. Return early without modifying the user.
|
||||
if (
|
||||
plan_id == "free"
|
||||
and str(getattr(subscription, "status", "")).lower() == "active"
|
||||
):
|
||||
logger.error(
|
||||
"Subscription %s is active but price ID is unrecognized — skipping update to avoid downgrade",
|
||||
subscription_id,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
|
||||
if user is None:
|
||||
logger.warning("No user found for Stripe customer %s; skipping subscription event", customer_id)
|
||||
logger.warning(
|
||||
"No user found for Stripe customer %s; skipping subscription event",
|
||||
customer_id,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
# Map Stripe status → SubscriptionStatus enum
|
||||
|
|
@ -398,9 +451,11 @@ async def _handle_subscription_event(
|
|||
limits = config.PLAN_LIMITS.get(plan_id, config.PLAN_LIMITS["free"])
|
||||
user.monthly_token_limit = limits["monthly_token_limit"]
|
||||
|
||||
# Upgrade pages_limit on activation
|
||||
# Upgrade pages_limit on activation; reset token counter date
|
||||
if new_status == SubscriptionStatus.ACTIVE:
|
||||
user.pages_limit = max(user.pages_used, limits["pages_limit"])
|
||||
if user.token_reset_date is None:
|
||||
user.token_reset_date = datetime.now(UTC).date()
|
||||
|
||||
# Downgrade pages_limit when canceling
|
||||
if new_status == SubscriptionStatus.CANCELED:
|
||||
|
|
@ -430,18 +485,25 @@ async def _handle_invoice_payment_succeeded(
|
|||
|
||||
# Reset tokens on subscription renewals and initial subscription creation
|
||||
if billing_reason not in {"subscription_cycle", "subscription_create"}:
|
||||
logger.info("invoice.payment_succeeded billing_reason=%s; not resetting tokens", billing_reason)
|
||||
logger.info(
|
||||
"invoice.payment_succeeded billing_reason=%s; not resetting tokens",
|
||||
billing_reason,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
|
||||
if user is None:
|
||||
logger.warning("No user found for Stripe customer %s; skipping token reset", customer_id)
|
||||
logger.warning(
|
||||
"No user found for Stripe customer %s; skipping token reset", customer_id
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
user.tokens_used_this_month = 0
|
||||
user.token_reset_date = datetime.now(UTC).date()
|
||||
|
||||
logger.info("Reset tokens_used_this_month for user %s on subscription renewal", user.id)
|
||||
logger.info(
|
||||
"Reset tokens_used_this_month for user %s on subscription renewal", user.id
|
||||
)
|
||||
await db_session.commit()
|
||||
return StripeWebhookResponse()
|
||||
|
||||
|
|
@ -456,7 +518,10 @@ async def _handle_invoice_payment_failed(
|
|||
|
||||
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
|
||||
if user is None:
|
||||
logger.warning("No user found for Stripe customer %s; skipping past_due update", customer_id)
|
||||
logger.warning(
|
||||
"No user found for Stripe customer %s; skipping past_due update",
|
||||
customer_id,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
if user.subscription_status == SubscriptionStatus.ACTIVE:
|
||||
|
|
@ -464,7 +529,11 @@ async def _handle_invoice_payment_failed(
|
|||
logger.info("Set subscription to PAST_DUE for user %s", user.id)
|
||||
await db_session.commit()
|
||||
else:
|
||||
logger.info("invoice.payment_failed for user %s already in status %s; no change", user.id, user.subscription_status)
|
||||
logger.info(
|
||||
"invoice.payment_failed for user %s already in status %s; no change",
|
||||
user.id,
|
||||
user.subscription_status,
|
||||
)
|
||||
|
||||
return StripeWebhookResponse()
|
||||
|
||||
|
|
@ -477,26 +546,43 @@ async def _activate_subscription_from_checkout(
|
|||
The full subscription lifecycle will also be handled by customer.subscription.created,
|
||||
but we activate immediately here so the user sees Pro access right after checkout.
|
||||
"""
|
||||
customer_id = _normalize_optional_string(getattr(checkout_session, "customer", None))
|
||||
subscription_id = _normalize_optional_string(getattr(checkout_session, "subscription", None))
|
||||
customer_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "customer", None)
|
||||
)
|
||||
subscription_id = _normalize_optional_string(
|
||||
getattr(checkout_session, "subscription", None)
|
||||
)
|
||||
metadata = _get_metadata(checkout_session)
|
||||
plan_id_str = metadata.get("plan_id", "")
|
||||
|
||||
if not customer_id:
|
||||
logger.error("Subscription checkout session missing customer ID: %s", getattr(checkout_session, "id", ""))
|
||||
logger.error(
|
||||
"Subscription checkout session missing customer ID: %s",
|
||||
getattr(checkout_session, "id", ""),
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
user = await _get_user_by_stripe_customer_id(db_session, customer_id)
|
||||
if user is None:
|
||||
logger.warning("No user found for Stripe customer %s; skipping subscription activation", customer_id)
|
||||
logger.warning(
|
||||
"No user found for Stripe customer %s; skipping subscription activation",
|
||||
customer_id,
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
# Idempotency: already activated
|
||||
if user.subscription_status == SubscriptionStatus.ACTIVE and user.stripe_subscription_id == subscription_id:
|
||||
logger.info("Subscription already active for user %s; skipping activation", user.id)
|
||||
if (
|
||||
user.subscription_status == SubscriptionStatus.ACTIVE
|
||||
and user.stripe_subscription_id == subscription_id
|
||||
):
|
||||
logger.info(
|
||||
"Subscription already active for user %s; skipping activation", user.id
|
||||
)
|
||||
return StripeWebhookResponse()
|
||||
|
||||
plan_id = plan_id_str if plan_id_str in {"pro_monthly", "pro_yearly"} else "pro_monthly"
|
||||
plan_id = (
|
||||
plan_id_str if plan_id_str in {"pro_monthly", "pro_yearly"} else "pro_monthly"
|
||||
)
|
||||
limits = config.PLAN_LIMITS.get(plan_id, config.PLAN_LIMITS["pro_monthly"])
|
||||
|
||||
user.subscription_status = SubscriptionStatus.ACTIVE
|
||||
|
|
@ -512,11 +598,20 @@ async def _activate_subscription_from_checkout(
|
|||
try:
|
||||
stripe_client = get_stripe_client()
|
||||
sub_obj = stripe_client.v1.subscriptions.retrieve(subscription_id)
|
||||
user.subscription_current_period_end = _period_end_from_subscription(sub_obj)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.warning("Could not retrieve subscription %s for period_end", subscription_id)
|
||||
user.subscription_current_period_end = _period_end_from_subscription(
|
||||
sub_obj
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not retrieve subscription %s for period_end", subscription_id
|
||||
)
|
||||
|
||||
logger.info("Activated subscription for user %s: plan=%s subscription=%s", user.id, plan_id, subscription_id)
|
||||
logger.info(
|
||||
"Activated subscription for user %s: plan=%s subscription=%s",
|
||||
user.id,
|
||||
plan_id,
|
||||
subscription_id,
|
||||
)
|
||||
await db_session.commit()
|
||||
return StripeWebhookResponse()
|
||||
|
||||
|
|
@ -601,6 +696,47 @@ async def create_subscription_checkout(
|
|||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> CreateSubscriptionCheckoutResponse:
|
||||
"""Create a Stripe Checkout Session for a recurring subscription."""
|
||||
# Admin-approval mode: when Stripe is not configured, queue a manual request
|
||||
if not config.STRIPE_SECRET_KEY:
|
||||
if user.subscription_status == SubscriptionStatus.ACTIVE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="You already have an active subscription.",
|
||||
)
|
||||
existing = await db_session.execute(
|
||||
select(SubscriptionRequest)
|
||||
.where(SubscriptionRequest.user_id == user.id)
|
||||
.where(SubscriptionRequest.status == SubscriptionRequestStatus.PENDING)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="You already have a pending subscription request.",
|
||||
)
|
||||
cooldown_cutoff = datetime.now(UTC) - timedelta(hours=24)
|
||||
recently_rejected = await db_session.execute(
|
||||
select(SubscriptionRequest)
|
||||
.where(SubscriptionRequest.user_id == user.id)
|
||||
.where(SubscriptionRequest.status == SubscriptionRequestStatus.REJECTED)
|
||||
.where(SubscriptionRequest.created_at >= cooldown_cutoff)
|
||||
)
|
||||
if recently_rejected.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Your previous request was rejected. Please wait 24 hours before resubmitting.",
|
||||
)
|
||||
req = SubscriptionRequest(user_id=user.id, plan_id=body.plan_id.value)
|
||||
db_session.add(req)
|
||||
await db_session.commit()
|
||||
logger.info(
|
||||
"Admin-approval subscription request created for user %s (plan=%s)",
|
||||
user.id,
|
||||
body.plan_id.value,
|
||||
)
|
||||
return CreateSubscriptionCheckoutResponse(
|
||||
checkout_url="", admin_approval_mode=True
|
||||
)
|
||||
|
||||
stripe_client = get_stripe_client()
|
||||
price_id = _get_price_id_for_plan(body.plan_id)
|
||||
success_url, cancel_url = _get_subscription_urls()
|
||||
|
|
@ -653,6 +789,7 @@ async def verify_checkout_session(
|
|||
user: User = Depends(current_active_user),
|
||||
) -> dict:
|
||||
"""Verify a Stripe Checkout Session belongs to the user and is paid."""
|
||||
_check_verify_session_rate_limit(str(user.id))
|
||||
stripe_client = get_stripe_client()
|
||||
try:
|
||||
session = stripe_client.v1.checkout.sessions.retrieve(session_id)
|
||||
|
|
@ -743,7 +880,9 @@ async def stripe_webhook(
|
|||
return StripeWebhookResponse()
|
||||
|
||||
if session_mode == "subscription":
|
||||
return await _activate_subscription_from_checkout(db_session, checkout_session)
|
||||
return await _activate_subscription_from_checkout(
|
||||
db_session, checkout_session
|
||||
)
|
||||
|
||||
return await _fulfill_completed_purchase(db_session, checkout_session)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,12 @@ from app.db import (
|
|||
from app.schemas import (
|
||||
GlobalVisionLLMConfigRead,
|
||||
VisionLLMConfigCreate,
|
||||
VisionLLMConfigPublic,
|
||||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
from app.services.vision_model_list_service import get_vision_model_list
|
||||
from app.users import current_active_user
|
||||
from app.users import current_active_user, current_superuser
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -118,18 +119,11 @@ async def get_global_vision_llm_configs(
|
|||
async def create_vision_llm_config(
|
||||
config_data: VisionLLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""Create a new vision LLM config. Superuser only."""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id)
|
||||
db_config = VisionLLMConfig(**config_data.model_dump(), user_id=None)
|
||||
session.add(db_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
|
@ -145,7 +139,7 @@ async def create_vision_llm_config(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead])
|
||||
@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigPublic])
|
||||
async def list_vision_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
|
|
@ -164,7 +158,10 @@ async def list_vision_llm_configs(
|
|||
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig)
|
||||
.filter(VisionLLMConfig.search_space_id == search_space_id)
|
||||
.filter(
|
||||
(VisionLLMConfig.search_space_id == search_space_id)
|
||||
| (VisionLLMConfig.search_space_id == None) # noqa: E711
|
||||
)
|
||||
.order_by(VisionLLMConfig.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
|
|
@ -217,24 +214,19 @@ async def update_vision_llm_config(
|
|||
config_id: int,
|
||||
update_data: VisionLLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""Update an existing vision LLM config. Superuser only."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
select(VisionLLMConfig).filter(
|
||||
VisionLLMConfig.id == config_id, VisionLLMConfig.user_id.is_(None)
|
||||
)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_CREATE.value,
|
||||
"You don't have permission to update vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
for key, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(db_config, key, value)
|
||||
|
||||
|
|
@ -256,24 +248,19 @@ async def update_vision_llm_config(
|
|||
async def delete_vision_llm_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
user: User = Depends(current_superuser),
|
||||
):
|
||||
"""Delete a vision LLM config. Superuser only."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id)
|
||||
select(VisionLLMConfig).filter(
|
||||
VisionLLMConfig.id == config_id, VisionLLMConfig.user_id.is_(None)
|
||||
)
|
||||
)
|
||||
db_config = result.scalars().first()
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.VISION_CONFIGS_DELETE.value,
|
||||
"You don't have permission to delete vision LLM configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -55,8 +55,9 @@ class ImageGenerationConfigBase(BaseModel):
|
|||
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
|
||||
"""Schema for creating a new ImageGenerationConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
search_space_id: int | None = Field(
|
||||
None,
|
||||
description="Search space ID. None = global admin config visible to all spaces",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -79,8 +80,8 @@ class ImageGenerationConfigRead(ImageGenerationConfigBase):
|
|||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
search_space_id: int | None = None
|
||||
user_id: uuid.UUID | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -98,8 +99,8 @@ class ImageGenerationConfigPublic(BaseModel):
|
|||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
search_space_id: int | None = None
|
||||
user_id: uuid.UUID | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -60,8 +60,9 @@ class NewLLMConfigBase(BaseModel):
|
|||
class NewLLMConfigCreate(NewLLMConfigBase):
|
||||
"""Schema for creating a new NewLLMConfig."""
|
||||
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the config with"
|
||||
search_space_id: int | None = Field(
|
||||
None,
|
||||
description="Search space ID. None = global admin config visible to all spaces",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -90,8 +91,8 @@ class NewLLMConfigRead(NewLLMConfigBase):
|
|||
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
search_space_id: int | None = None
|
||||
user_id: uuid.UUID | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -119,8 +120,8 @@ class NewLLMConfigPublic(BaseModel):
|
|||
citations_enabled: bool
|
||||
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
search_space_id: int | None = None
|
||||
user_id: uuid.UUID | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class CreateSubscriptionCheckoutResponse(BaseModel):
|
|||
"""Response containing the Stripe-hosted subscription checkout URL."""
|
||||
|
||||
checkout_url: str
|
||||
admin_approval_mode: bool = False
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ from fastapi_users import schemas
|
|||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
pages_limit: int
|
||||
pages_used: int
|
||||
monthly_token_limit: int
|
||||
tokens_used_this_month: int
|
||||
plan_id: str
|
||||
subscription_status: str
|
||||
display_name: str | None = None
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,10 @@ class VisionLLMConfigBase(BaseModel):
|
|||
|
||||
|
||||
class VisionLLMConfigCreate(VisionLLMConfigBase):
|
||||
search_space_id: int = Field(...)
|
||||
search_space_id: int | None = Field(
|
||||
None,
|
||||
description="Search space ID. None = global admin config visible to all spaces",
|
||||
)
|
||||
|
||||
|
||||
class VisionLLMConfigUpdate(BaseModel):
|
||||
|
|
@ -38,8 +41,8 @@ class VisionLLMConfigUpdate(BaseModel):
|
|||
class VisionLLMConfigRead(VisionLLMConfigBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
search_space_id: int | None = None
|
||||
user_id: uuid.UUID | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -55,8 +58,8 @@ class VisionLLMConfigPublic(BaseModel):
|
|||
api_version: str | None = None
|
||||
litellm_params: dict[str, Any] | None = None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
search_space_id: int | None = None
|
||||
user_id: uuid.UUID | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -51,16 +51,25 @@ class PageLimitService:
|
|||
"""
|
||||
from app.db import User
|
||||
|
||||
# Get user's current page usage
|
||||
# Get user's current page usage and subscription status
|
||||
result = await self.session.execute(
|
||||
select(User.pages_used, User.pages_limit).where(User.id == user_id)
|
||||
select(User.pages_used, User.pages_limit, User.subscription_status).where(
|
||||
User.id == user_id
|
||||
)
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
|
||||
pages_used, pages_limit = row
|
||||
pages_used, pages_limit, sub_status = row
|
||||
|
||||
# PAST_DUE: enforce free-tier page limit to prevent usage without payment
|
||||
if str(sub_status).lower() == "past_due":
|
||||
from app.config import config as app_config # avoid circular import
|
||||
|
||||
free_limit = app_config.PLAN_LIMITS.get("free", {}).get("pages_limit", 500)
|
||||
pages_limit = min(pages_limit, free_limit)
|
||||
|
||||
# Check if adding estimated pages would exceed limit
|
||||
if pages_used + estimated_pages > pages_limit:
|
||||
|
|
|
|||
|
|
@ -97,6 +97,15 @@ class TokenQuotaService:
|
|||
tokens_used = user.tokens_used_this_month or 0
|
||||
token_limit = user.monthly_token_limit or 0
|
||||
|
||||
# PAST_DUE: enforce free-tier token limit to prevent usage without payment
|
||||
if str(getattr(user, "subscription_status", "")).lower() == "past_due":
|
||||
from app.config import config as app_config # avoid circular import
|
||||
|
||||
free_limit = app_config.PLAN_LIMITS.get("free", {}).get(
|
||||
"monthly_token_limit", 50000
|
||||
)
|
||||
token_limit = min(token_limit, free_limit)
|
||||
|
||||
# Strict boundary: >= means at-limit is also exceeded
|
||||
if tokens_used + estimated_tokens >= token_limit and token_limit > 0:
|
||||
raise TokenQuotaExceededError(
|
||||
|
|
|
|||
|
|
@ -300,3 +300,4 @@ fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
|||
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
current_optional_user = fastapi_users.current_user(active=True, optional=True)
|
||||
current_superuser = fastapi_users.current_user(active=True, superuser=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue