diff --git a/internal/infrastructure/auth/tenants.go b/internal/infrastructure/auth/tenants.go index 8dce5b8..2a477e0 100644 --- a/internal/infrastructure/auth/tenants.go +++ b/internal/infrastructure/auth/tenants.go @@ -124,20 +124,18 @@ func (s *TenantStore) migrate() error { stripe_customer_id TEXT DEFAULT '', stripe_sub_id TEXT DEFAULT '', owner_user_id TEXT NOT NULL, - active INTEGER NOT NULL DEFAULT 1, - created_at TEXT NOT NULL, + active BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), events_this_month INTEGER NOT NULL DEFAULT 0, - month_reset_at TEXT NOT NULL + month_reset_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); - -- Add tenant_id to users table if not exists - -- SQLite doesn't support ADD COLUMN IF NOT EXISTS, so we use a trick `) if err != nil { return err } // Add tenant_id column to users if missing - _, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN tenant_id TEXT DEFAULT ''`) + _, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS tenant_id TEXT DEFAULT ''`) return nil } @@ -154,13 +152,11 @@ func (s *TenantStore) loadFromDB() { defer s.mu.Unlock() for rows.Next() { var t Tenant - var createdAt, monthReset string if err := rows.Scan(&t.ID, &t.Name, &t.Slug, &t.PlanID, &t.PaymentCustomerID, - &t.PaymentSubID, &t.OwnerUserID, &t.Active, &createdAt, &t.EventsThisMonth, &monthReset); err != nil { + &t.PaymentSubID, &t.OwnerUserID, &t.Active, &t.CreatedAt, &t.EventsThisMonth, &t.MonthResetAt); err != nil { + slog.Warn("load tenant row scan", "error", err) continue } - t.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - t.MonthResetAt, _ = time.Parse(time.RFC3339, monthReset) s.tenants[t.ID] = &t } slog.Info("tenants loaded from DB", "count", len(s.tenants)) @@ -171,12 +167,21 @@ func (s *TenantStore) persistTenant(t *Tenant) { return } _, err := s.db.Exec(` - INSERT OR REPLACE INTO tenants (id, name, slug, plan_id, stripe_customer_id, stripe_sub_id, + INSERT INTO tenants (id, name, slug, plan_id, stripe_customer_id, stripe_sub_id, owner_user_id, active, created_at, events_this_month, month_reset_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT (id) DO UPDATE SET + name = EXCLUDED.name, + slug = EXCLUDED.slug, + plan_id = EXCLUDED.plan_id, + stripe_customer_id = EXCLUDED.stripe_customer_id, + stripe_sub_id = EXCLUDED.stripe_sub_id, + active = EXCLUDED.active, + events_this_month = EXCLUDED.events_this_month, + month_reset_at = EXCLUDED.month_reset_at`, t.ID, t.Name, t.Slug, t.PlanID, t.PaymentCustomerID, t.PaymentSubID, - t.OwnerUserID, t.Active, t.CreatedAt.Format(time.RFC3339), - t.EventsThisMonth, t.MonthResetAt.Format(time.RFC3339), + t.OwnerUserID, t.Active, t.CreatedAt, + t.EventsThisMonth, t.MonthResetAt, ) if err != nil { slog.Error("persist tenant", "id", t.ID, "error", err) diff --git a/internal/infrastructure/auth/users.go b/internal/infrastructure/auth/users.go index 39b7c77..44ca00f 100644 --- a/internal/infrastructure/auth/users.go +++ b/internal/infrastructure/auth/users.go @@ -95,13 +95,13 @@ func (s *UserStore) migrate() error { email TEXT UNIQUE NOT NULL, display_name TEXT NOT NULL DEFAULT '', role TEXT NOT NULL DEFAULT 'viewer', - active INTEGER NOT NULL DEFAULT 1, - email_verified INTEGER NOT NULL DEFAULT 0, + active BOOLEAN NOT NULL DEFAULT true, + email_verified BOOLEAN NOT NULL DEFAULT false, password_hash TEXT NOT NULL, verify_token TEXT DEFAULT '', - verify_expiry TEXT DEFAULT '', - created_at TEXT NOT NULL, - last_login_at TEXT + verify_expiry TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_login_at TIMESTAMPTZ ); CREATE TABLE IF NOT EXISTS api_keys ( id TEXT PRIMARY KEY, @@ -109,21 +109,22 @@ func (s *UserStore) migrate() error { key_hash TEXT NOT NULL, name TEXT NOT NULL DEFAULT '', role TEXT NOT NULL DEFAULT 'viewer', - created_at TEXT NOT NULL, - last_used TEXT + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_used TIMESTAMPTZ ); `) if err != nil { return err } - // Add columns if upgrading from older schema - s.db.Exec(`ALTER TABLE users ADD COLUMN email_verified INTEGER NOT NULL DEFAULT 0`) - s.db.Exec(`ALTER TABLE users ADD COLUMN verify_token TEXT DEFAULT ''`) - s.db.Exec(`ALTER TABLE users ADD COLUMN verify_expiry TEXT DEFAULT ''`) + // Add columns if upgrading from older schema (ignore errors if column exists) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS display_name TEXT NOT NULL DEFAULT ''`) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS email_verified BOOLEAN NOT NULL DEFAULT false`) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS verify_token TEXT DEFAULT ''`) + s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS verify_expiry TIMESTAMPTZ`) return nil } -// loadFromDB loads all users from SQLite into memory cache. +// loadFromDB loads all users from DB into memory cache. func (s *UserStore) loadFromDB() { rows, err := s.db.Query(`SELECT id, email, display_name, role, active, password_hash, created_at, last_login_at FROM users`) if err != nil { @@ -136,43 +137,38 @@ func (s *UserStore) loadFromDB() { defer s.mu.Unlock() for rows.Next() { var u User - var createdAt string - var lastLogin sql.NullString - if err := rows.Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role, &u.Active, &u.PasswordHash, &createdAt, &lastLogin); err != nil { + var lastLogin sql.NullTime + if err := rows.Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role, &u.Active, &u.PasswordHash, &u.CreatedAt, &lastLogin); err != nil { + slog.Warn("load user row scan", "error", err) continue } - u.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) if lastLogin.Valid { - t, _ := time.Parse(time.RFC3339, lastLogin.String) - u.LastLoginAt = &t + u.LastLoginAt = &lastLogin.Time } s.users[u.Email] = &u } slog.Info("users loaded from DB", "count", len(s.users)) } -// persistUser writes a user to SQLite. +// persistUser writes a user to DB (PostgreSQL-compatible upsert). func (s *UserStore) persistUser(u *User) { if s.db == nil { return } - var lastLogin *string - if u.LastLoginAt != nil { - t := u.LastLoginAt.Format(time.RFC3339) - lastLogin = &t - } - var verifyExpiry string - if u.VerifyExpiry != nil { - verifyExpiry = u.VerifyExpiry.Format(time.RFC3339) - } - verified := 0 - if u.EmailVerified { - verified = 1 - } _, err := s.db.Exec(` - INSERT OR REPLACE INTO users (id, email, display_name, role, active, email_verified, password_hash, verify_token, verify_expiry, created_at, last_login_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - u.ID, u.Email, u.DisplayName, u.Role, u.Active, verified, u.PasswordHash, u.VerifyToken, verifyExpiry, u.CreatedAt.Format(time.RFC3339), lastLogin, + INSERT INTO users (id, email, display_name, role, active, email_verified, password_hash, verify_token, verify_expiry, created_at, last_login_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT (id) DO UPDATE SET + email = EXCLUDED.email, + display_name = EXCLUDED.display_name, + role = EXCLUDED.role, + active = EXCLUDED.active, + email_verified = EXCLUDED.email_verified, + password_hash = EXCLUDED.password_hash, + verify_token = EXCLUDED.verify_token, + verify_expiry = EXCLUDED.verify_expiry, + last_login_at = EXCLUDED.last_login_at`, + u.ID, u.Email, u.DisplayName, u.Role, u.Active, u.EmailVerified, u.PasswordHash, u.VerifyToken, u.VerifyExpiry, u.CreatedAt, u.LastLoginAt, ) if err != nil { slog.Error("persist user", "email", u.Email, "error", err) @@ -366,7 +362,7 @@ func (s *UserStore) DeleteUser(id string) error { if u.ID == id { delete(s.users, email) if s.db != nil { - go s.db.Exec(`DELETE FROM users WHERE id = ?`, id) + go s.db.Exec(`DELETE FROM users WHERE id = $1`, id) } return nil } @@ -406,8 +402,8 @@ func (s *UserStore) CreateAPIKey(userID, name, role string) (string, *APIKey, er } if s.db != nil { - _, err := s.db.Exec(`INSERT INTO api_keys (id, user_id, key_hash, name, role, created_at) VALUES (?,?,?,?,?,?)`, - ak.ID, ak.UserID, keyHash, ak.Name, ak.Role, ak.CreatedAt.Format(time.RFC3339)) + _, err := s.db.Exec(`INSERT INTO api_keys (id, user_id, key_hash, name, role, created_at) VALUES ($1,$2,$3,$4,$5,$6)`, + ak.ID, ak.UserID, keyHash, ak.Name, ak.Role, ak.CreatedAt) if err != nil { return "", nil, err } @@ -423,13 +419,13 @@ func (s *UserStore) ValidateAPIKey(key string) (string, string, error) { } keyHash := hashKey(key) var userID, role string - err := s.db.QueryRow(`SELECT user_id, role FROM api_keys WHERE key_hash = ?`, keyHash).Scan(&userID, &role) + err := s.db.QueryRow(`SELECT user_id, role FROM api_keys WHERE key_hash = $1`, keyHash).Scan(&userID, &role) if err != nil { return "", "", ErrInvalidToken } // Update last_used - go s.db.Exec(`UPDATE api_keys SET last_used = ? WHERE key_hash = ?`, time.Now().Format(time.RFC3339), keyHash) + go s.db.Exec(`UPDATE api_keys SET last_used = $1 WHERE key_hash = $2`, time.Now(), keyHash) return userID, role, nil } @@ -438,7 +434,7 @@ func (s *UserStore) ListAPIKeys(userID string) ([]APIKey, error) { if s.db == nil { return nil, nil } - rows, err := s.db.Query(`SELECT id, user_id, name, role, created_at, last_used FROM api_keys WHERE user_id = ?`, userID) + rows, err := s.db.Query(`SELECT id, user_id, name, role, created_at, last_used FROM api_keys WHERE user_id = $1`, userID) if err != nil { return nil, err } @@ -447,15 +443,12 @@ func (s *UserStore) ListAPIKeys(userID string) ([]APIKey, error) { var keys []APIKey for rows.Next() { var ak APIKey - var createdAt string - var lastUsed sql.NullString - if err := rows.Scan(&ak.ID, &ak.UserID, &ak.Name, &ak.Role, &createdAt, &lastUsed); err != nil { + var lastUsed sql.NullTime + if err := rows.Scan(&ak.ID, &ak.UserID, &ak.Name, &ak.Role, &ak.CreatedAt, &lastUsed); err != nil { continue } - ak.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) if lastUsed.Valid { - t, _ := time.Parse(time.RFC3339, lastUsed.String) - ak.LastUsed = &t + ak.LastUsed = &lastUsed.Time } keys = append(keys, ak) } @@ -467,7 +460,7 @@ func (s *UserStore) DeleteAPIKey(keyID, userID string) error { if s.db == nil { return nil } - _, err := s.db.Exec(`DELETE FROM api_keys WHERE id = ? AND user_id = ?`, keyID, userID) + _, err := s.db.Exec(`DELETE FROM api_keys WHERE id = $1 AND user_id = $2`, keyID, userID) return err }