Release prep: 54 engines, self-hosted signatures, i18n, dashboard updates

This commit is contained in:
DmitrL-dev 2026-03-23 16:45:40 +10:00
parent 694e32be26
commit 41cbfd6e0a
178 changed files with 36008 additions and 399 deletions

View file

@ -0,0 +1,367 @@
package auth
import (
"encoding/json"
"net/http"
"strings"
)
// LoginRequest is the POST /api/auth/login body.
type LoginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
// TokenResponse is returned on successful login/refresh.
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` // seconds
TokenType string `json:"token_type"`
User *User `json:"user"`
}
// HandleLogin creates an HTTP handler for POST /api/auth/login.
func HandleLogin(store *UserStore, secret []byte) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAuthError(w, http.StatusBadRequest, "invalid JSON body")
return
}
// Support both "email" and legacy "username" field
email := req.Email
if email == "" {
// Try legacy format
var legacy struct{ Username string `json:"username"` }
email = legacy.Username
}
user, err := store.Authenticate(email, req.Password)
if err != nil {
if err == ErrEmailNotVerified {
writeAuthError(w, http.StatusForbidden, "email not verified — check your inbox for the verification code")
return
}
writeAuthError(w, http.StatusUnauthorized, "invalid credentials")
return
}
accessToken, err := NewAccessToken(user.Email, user.Role, secret, 0)
if err != nil {
writeAuthError(w, http.StatusInternalServerError, "token generation failed")
return
}
refreshToken, err := NewRefreshToken(user.Email, user.Role, secret, 0)
if err != nil {
writeAuthError(w, http.StatusInternalServerError, "token generation failed")
return
}
resp := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: 900, // 15 minutes
TokenType: "Bearer",
User: user,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
}
// HandleRefresh creates an HTTP handler for POST /api/auth/refresh.
func HandleRefresh(secret []byte) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAuthError(w, http.StatusBadRequest, "invalid JSON body")
return
}
claims, err := Verify(req.RefreshToken, secret)
if err != nil {
writeAuthError(w, http.StatusUnauthorized, "invalid or expired refresh token")
return
}
accessToken, err := NewAccessToken(claims.Sub, claims.Role, secret, 0)
if err != nil {
writeAuthError(w, http.StatusInternalServerError, "token generation failed")
return
}
resp := TokenResponse{
AccessToken: accessToken,
RefreshToken: req.RefreshToken,
ExpiresIn: 900,
TokenType: "Bearer",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
}
// HandleMe returns the current authenticated user profile.
// GET /api/auth/me
func HandleMe(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil {
writeAuthError(w, http.StatusUnauthorized, "not authenticated")
return
}
user, err := store.GetByEmail(claims.Sub)
if err != nil {
writeAuthError(w, http.StatusNotFound, "user not found")
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(user)
}
}
// HandleListUsers returns all users (admin only).
// GET /api/auth/users
func HandleListUsers(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
users := store.ListUsers()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"users": users,
"total": len(users),
})
}
}
// HandleCreateUser creates a new user (admin only).
// POST /api/auth/users
func HandleCreateUser(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
DisplayName string `json:"display_name"`
Password string `json:"password"`
Role string `json:"role"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAuthError(w, http.StatusBadRequest, "invalid JSON")
return
}
if req.Email == "" || req.Password == "" {
writeAuthError(w, http.StatusBadRequest, "email and password required")
return
}
if req.Role == "" {
req.Role = "viewer"
}
// Validate role
validRoles := map[string]bool{"admin": true, "analyst": true, "viewer": true}
if !validRoles[req.Role] {
writeAuthError(w, http.StatusBadRequest, "invalid role (valid: admin, analyst, viewer)")
return
}
user, err := store.CreateUser(req.Email, req.DisplayName, req.Password, req.Role)
if err != nil {
if err == ErrUserExists {
writeAuthError(w, http.StatusConflict, "user already exists")
} else {
writeAuthError(w, http.StatusInternalServerError, err.Error())
}
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(user)
}
}
// HandleUpdateUser updates a user's profile (admin only).
// PUT /api/auth/users/{id}
func HandleUpdateUser(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if id == "" {
writeAuthError(w, http.StatusBadRequest, "user id required")
return
}
var req struct {
DisplayName string `json:"display_name"`
Role string `json:"role"`
Active *bool `json:"active"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAuthError(w, http.StatusBadRequest, "invalid JSON")
return
}
active := true
if req.Active != nil {
active = *req.Active
}
if err := store.UpdateUser(id, req.DisplayName, req.Role, active); err != nil {
writeAuthError(w, http.StatusNotFound, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "updated"})
}
}
// HandleDeleteUser deletes a user (admin only).
// DELETE /api/auth/users/{id}
func HandleDeleteUser(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if id == "" {
writeAuthError(w, http.StatusBadRequest, "user id required")
return
}
if err := store.DeleteUser(id); err != nil {
writeAuthError(w, http.StatusNotFound, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "deleted"})
}
}
// HandleCreateAPIKey generates a new API key for the authenticated user.
// POST /api/auth/keys
func HandleCreateAPIKey(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil {
writeAuthError(w, http.StatusUnauthorized, "not authenticated")
return
}
user, err := store.GetByEmail(claims.Sub)
if err != nil {
writeAuthError(w, http.StatusNotFound, "user not found")
return
}
var req struct {
Name string `json:"name"`
Role string `json:"role"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeAuthError(w, http.StatusBadRequest, "invalid JSON")
return
}
if req.Name == "" {
req.Name = "default"
}
if req.Role == "" {
req.Role = user.Role
}
fullKey, ak, err := store.CreateAPIKey(user.ID, req.Name, req.Role)
if err != nil {
writeAuthError(w, http.StatusInternalServerError, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"key": fullKey, // shown only once
"details": ak,
})
}
}
// HandleListAPIKeys returns API keys for the authenticated user.
// GET /api/auth/keys
func HandleListAPIKeys(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil {
writeAuthError(w, http.StatusUnauthorized, "not authenticated")
return
}
user, err := store.GetByEmail(claims.Sub)
if err != nil {
writeAuthError(w, http.StatusNotFound, "user not found")
return
}
keys, err := store.ListAPIKeys(user.ID)
if err != nil {
writeAuthError(w, http.StatusInternalServerError, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{"keys": keys})
}
}
// HandleDeleteAPIKey revokes an API key.
// DELETE /api/auth/keys/{id}
func HandleDeleteAPIKey(store *UserStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil {
writeAuthError(w, http.StatusUnauthorized, "not authenticated")
return
}
user, err := store.GetByEmail(claims.Sub)
if err != nil {
writeAuthError(w, http.StatusNotFound, "user not found")
return
}
keyID := r.PathValue("id")
if err := store.DeleteAPIKey(keyID, user.ID); err != nil {
writeAuthError(w, http.StatusInternalServerError, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "revoked"})
}
}
// APIKeyMiddleware checks for API key authentication alongside JWT.
// If Authorization header starts with "stx_", validate as API key.
func APIKeyMiddleware(store *UserStore, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer stx_") {
key := strings.TrimPrefix(authHeader, "Bearer ")
_, role, err := store.ValidateAPIKey(key)
if err != nil {
writeAuthError(w, http.StatusUnauthorized, "invalid API key")
return
}
// Inject synthetic claims for RBAC compatibility
claims := &Claims{Sub: "api-key", Role: role}
ctx := SetClaimsContext(r.Context(), claims)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,136 @@
// Package auth provides JWT authentication for the SOC HTTP API.
// Uses HMAC-SHA256 (HS256) with configurable secret.
// Zero external dependencies — pure Go stdlib.
package auth
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
// Standard JWT errors.
var (
ErrInvalidToken = errors.New("auth: invalid token")
ErrExpiredToken = errors.New("auth: token expired")
ErrInvalidSecret = errors.New("auth: secret too short (min 32 bytes)")
)
// Claims represents JWT payload.
type Claims struct {
Sub string `json:"sub"` // Subject (username or user ID)
Role string `json:"role"` // RBAC role: admin, operator, analyst, viewer
TenantID string `json:"tenant_id,omitempty"` // Multi-tenant isolation
Exp int64 `json:"exp"` // Expiration (Unix timestamp)
Iat int64 `json:"iat"` // Issued at
Iss string `json:"iss,omitempty"` // Issuer
}
// IsExpired returns true if the token has expired.
func (c Claims) IsExpired() bool {
return time.Now().Unix() > c.Exp
}
// header is the JWT header (always HS256).
var jwtHeader = base64URLEncode([]byte(`{"alg":"HS256","typ":"JWT"}`))
// Sign creates a JWT token string from claims.
func Sign(claims Claims, secret []byte) (string, error) {
if len(secret) < 32 {
return "", ErrInvalidSecret
}
if claims.Iat == 0 {
claims.Iat = time.Now().Unix()
}
if claims.Iss == "" {
claims.Iss = "sentinel-soc"
}
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("auth: marshal claims: %w", err)
}
encodedPayload := base64URLEncode(payload)
signingInput := jwtHeader + "." + encodedPayload
signature := hmacSign([]byte(signingInput), secret)
return signingInput + "." + signature, nil
}
// Verify validates a JWT token string and returns the claims.
func Verify(tokenStr string, secret []byte) (*Claims, error) {
parts := strings.SplitN(tokenStr, ".", 3)
if len(parts) != 3 {
return nil, ErrInvalidToken
}
signingInput := parts[0] + "." + parts[1]
expectedSig := hmacSign([]byte(signingInput), secret)
if !hmac.Equal([]byte(parts[2]), []byte(expectedSig)) {
return nil, ErrInvalidToken
}
payload, err := base64URLDecode(parts[1])
if err != nil {
return nil, fmt.Errorf("%w: bad payload encoding", ErrInvalidToken)
}
var claims Claims
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("%w: bad payload JSON", ErrInvalidToken)
}
if claims.IsExpired() {
return nil, ErrExpiredToken
}
return &claims, nil
}
// NewAccessToken creates a short-lived access token (15 min default).
func NewAccessToken(subject, role string, secret []byte, ttl time.Duration) (string, error) {
if ttl == 0 {
ttl = 15 * time.Minute
}
return Sign(Claims{
Sub: subject,
Role: role,
Exp: time.Now().Add(ttl).Unix(),
}, secret)
}
// NewRefreshToken creates a long-lived refresh token (7 days default).
func NewRefreshToken(subject, role string, secret []byte, ttl time.Duration) (string, error) {
if ttl == 0 {
ttl = 7 * 24 * time.Hour
}
return Sign(Claims{
Sub: subject,
Role: role,
Exp: time.Now().Add(ttl).Unix(),
}, secret)
}
// --- base64url helpers (RFC 7515) ---
func base64URLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
func base64URLDecode(s string) ([]byte, error) {
return base64.RawURLEncoding.DecodeString(s)
}
func hmacSign(data, secret []byte) string {
mac := hmac.New(sha256.New, secret)
mac.Write(data)
return base64URLEncode(mac.Sum(nil))
}

View file

@ -0,0 +1,115 @@
package auth
import (
"testing"
"time"
)
var testSecret = []byte("test-secret-must-be-at-least-32-bytes-long!")
func TestSign_Verify_RoundTrip(t *testing.T) {
claims := Claims{
Sub: "admin",
Role: "admin",
Exp: time.Now().Add(time.Hour).Unix(),
}
token, err := Sign(claims, testSecret)
if err != nil {
t.Fatalf("Sign: %v", err)
}
got, err := Verify(token, testSecret)
if err != nil {
t.Fatalf("Verify: %v", err)
}
if got.Sub != "admin" {
t.Errorf("Sub = %q, want admin", got.Sub)
}
if got.Role != "admin" {
t.Errorf("Role = %q, want admin", got.Role)
}
if got.Iss != "sentinel-soc" {
t.Errorf("Iss = %q, want sentinel-soc", got.Iss)
}
}
func TestVerify_ExpiredToken(t *testing.T) {
token, _ := Sign(Claims{
Sub: "user",
Role: "viewer",
Exp: time.Now().Add(-time.Hour).Unix(),
}, testSecret)
_, err := Verify(token, testSecret)
if err != ErrExpiredToken {
t.Errorf("expected ErrExpiredToken, got %v", err)
}
}
func TestVerify_InvalidSignature(t *testing.T) {
token, _ := Sign(Claims{
Sub: "user",
Role: "viewer",
Exp: time.Now().Add(time.Hour).Unix(),
}, testSecret)
wrongSecret := []byte("wrong-secret-that-is-also-32-bytes-x")
_, err := Verify(token, wrongSecret)
if err != ErrInvalidToken {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
}
func TestVerify_MalformedToken(t *testing.T) {
_, err := Verify("not.a.valid.jwt", testSecret)
if err != ErrInvalidToken {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
_, err = Verify("", testSecret)
if err != ErrInvalidToken {
t.Errorf("expected ErrInvalidToken for empty token, got %v", err)
}
}
func TestSign_ShortSecret(t *testing.T) {
_, err := Sign(Claims{Sub: "x", Exp: time.Now().Add(time.Hour).Unix()}, []byte("short"))
if err != ErrInvalidSecret {
t.Errorf("expected ErrInvalidSecret, got %v", err)
}
}
func TestNewAccessToken(t *testing.T) {
token, err := NewAccessToken("analyst", "analyst", testSecret, 0)
if err != nil {
t.Fatalf("NewAccessToken: %v", err)
}
claims, err := Verify(token, testSecret)
if err != nil {
t.Fatalf("Verify: %v", err)
}
if claims.Sub != "analyst" || claims.Role != "analyst" {
t.Errorf("unexpected claims: %+v", claims)
}
// Default TTL = 15 min, check expiry is within 16 min
if claims.Exp > time.Now().Add(16*time.Minute).Unix() {
t.Error("access token TTL too long")
}
}
func TestNewRefreshToken(t *testing.T) {
token, err := NewRefreshToken("admin", "admin", testSecret, 0)
if err != nil {
t.Fatalf("NewRefreshToken: %v", err)
}
claims, err := Verify(token, testSecret)
if err != nil {
t.Fatalf("Verify: %v", err)
}
// Default TTL = 7 days
if claims.Exp < time.Now().Add(6*24*time.Hour).Unix() {
t.Error("refresh token TTL too short")
}
}

View file

@ -0,0 +1,97 @@
package auth
import (
"context"
"log/slog"
"net/http"
"strings"
)
type ctxKey string
const claimsKey ctxKey = "jwt_claims"
// JWTMiddleware validates Bearer tokens on protected routes.
type JWTMiddleware struct {
secret []byte
// PublicPaths are exempt from auth (e.g., /health, /api/auth/login).
PublicPaths map[string]bool
}
// NewJWTMiddleware creates JWT middleware with the given secret.
func NewJWTMiddleware(secret []byte) *JWTMiddleware {
return &JWTMiddleware{
secret: secret,
PublicPaths: map[string]bool{
"/health": true,
"/api/auth/login": true,
"/api/auth/refresh": true,
"/api/soc/events/stream": true, // SSE uses query param auth
"/api/soc/stream": true, // SSE live feed (EventSource can't send headers)
"/api/soc/ws": true, // WebSocket-style SSE push
},
}
}
// Middleware wraps an http.Handler with JWT validation.
func (m *JWTMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip auth for public paths.
if m.PublicPaths[r.URL.Path] {
next.ServeHTTP(w, r)
return
}
// Extract Bearer token.
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeAuthError(w, http.StatusUnauthorized, "missing Authorization header")
return
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
writeAuthError(w, http.StatusUnauthorized, "invalid Authorization format (expected: Bearer <token>)")
return
}
claims, err := Verify(parts[1], m.secret)
if err != nil {
slog.Warn("JWT auth failed",
"error", err,
"path", r.URL.Path,
"remote", r.RemoteAddr,
)
if err == ErrExpiredToken {
writeAuthError(w, http.StatusUnauthorized, "token expired")
} else {
writeAuthError(w, http.StatusUnauthorized, "invalid token")
}
return
}
// Inject claims into context for downstream handlers.
ctx := context.WithValue(r.Context(), claimsKey, claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetClaims extracts JWT claims from request context.
func GetClaims(ctx context.Context) *Claims {
if c, ok := ctx.Value(claimsKey).(*Claims); ok {
return c
}
return nil
}
// SetClaimsContext injects claims into a context (used by API key auth).
func SetClaimsContext(ctx context.Context, claims *Claims) context.Context {
return context.WithValue(ctx, claimsKey, claims)
}
func writeAuthError(w http.ResponseWriter, status int, msg string) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("WWW-Authenticate", `Bearer realm="sentinel-soc"`)
w.WriteHeader(status)
w.Write([]byte(`{"error":"` + msg + `"}`))
}

View file

@ -0,0 +1,119 @@
package auth
import (
"net/http"
"sync"
"time"
)
// RateLimiter tracks login attempts per IP using a sliding window.
type RateLimiter struct {
mu sync.Mutex
attempts map[string]*ipBucket
maxHits int
window time.Duration
cleanup time.Duration
}
type ipBucket struct {
timestamps []time.Time
}
// NewRateLimiter creates a rate limiter.
// maxHits: max attempts per window per IP.
// window: sliding window duration.
func NewRateLimiter(maxHits int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
attempts: make(map[string]*ipBucket),
maxHits: maxHits,
window: window,
cleanup: 5 * time.Minute,
}
go rl.cleanupLoop()
return rl
}
// Allow checks if the IP is within the rate limit.
// Returns true if allowed, false if rate-limited.
func (rl *RateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
bucket, ok := rl.attempts[ip]
if !ok {
bucket = &ipBucket{}
rl.attempts[ip] = bucket
}
// Prune old timestamps outside the window.
cutoff := now.Add(-rl.window)
valid := bucket.timestamps[:0]
for _, t := range bucket.timestamps {
if t.After(cutoff) {
valid = append(valid, t)
}
}
bucket.timestamps = valid
if len(bucket.timestamps) >= rl.maxHits {
return false
}
bucket.timestamps = append(bucket.timestamps, now)
return true
}
// Reset clears attempts for an IP (e.g., on successful login).
func (rl *RateLimiter) Reset(ip string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.attempts, ip)
}
func (rl *RateLimiter) cleanupLoop() {
ticker := time.NewTicker(rl.cleanup)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
now := time.Now()
cutoff := now.Add(-rl.window)
for ip, bucket := range rl.attempts {
valid := bucket.timestamps[:0]
for _, t := range bucket.timestamps {
if t.After(cutoff) {
valid = append(valid, t)
}
}
if len(valid) == 0 {
delete(rl.attempts, ip)
} else {
bucket.timestamps = valid
}
}
rl.mu.Unlock()
}
}
// RateLimitMiddleware wraps an http.HandlerFunc with rate limiting.
func RateLimitMiddleware(rl *RateLimiter, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ip := r.RemoteAddr
// Strip port if present.
if idx := len(ip) - 1; idx > 0 {
for i := idx; i >= 0; i-- {
if ip[i] == ':' {
ip = ip[:i]
break
}
}
}
if !rl.Allow(ip) {
w.Header().Set("Retry-After", "60")
writeAuthError(w, http.StatusTooManyRequests, "rate limit exceeded — try again later")
return
}
next(w, r)
}
}

View file

@ -0,0 +1,102 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRateLimiter_AllowUnderLimit(t *testing.T) {
rl := NewRateLimiter(5, time.Minute)
for i := 0; i < 5; i++ {
if !rl.Allow("192.168.1.1") {
t.Fatalf("request %d should be allowed", i+1)
}
}
}
func TestRateLimiter_BlockOverLimit(t *testing.T) {
rl := NewRateLimiter(5, time.Minute)
for i := 0; i < 5; i++ {
rl.Allow("192.168.1.1")
}
if rl.Allow("192.168.1.1") {
t.Fatal("6th request should be blocked")
}
}
func TestRateLimiter_DifferentIPs(t *testing.T) {
rl := NewRateLimiter(2, time.Minute)
rl.Allow("10.0.0.1")
rl.Allow("10.0.0.1")
// IP 1 is exhausted.
if rl.Allow("10.0.0.1") {
t.Fatal("IP 10.0.0.1 should be blocked")
}
// IP 2 should still be allowed.
if !rl.Allow("10.0.0.2") {
t.Fatal("IP 10.0.0.2 should be allowed")
}
}
func TestRateLimiter_WindowExpiry(t *testing.T) {
rl := NewRateLimiter(2, 50*time.Millisecond)
rl.Allow("10.0.0.1")
rl.Allow("10.0.0.1")
if rl.Allow("10.0.0.1") {
t.Fatal("should be blocked before window expires")
}
time.Sleep(60 * time.Millisecond)
if !rl.Allow("10.0.0.1") {
t.Fatal("should be allowed after window expires")
}
}
func TestRateLimiter_Reset(t *testing.T) {
rl := NewRateLimiter(2, time.Minute)
rl.Allow("10.0.0.1")
rl.Allow("10.0.0.1")
if rl.Allow("10.0.0.1") {
t.Fatal("should be blocked")
}
rl.Reset("10.0.0.1")
if !rl.Allow("10.0.0.1") {
t.Fatal("should be allowed after reset")
}
}
func TestRateLimitMiddleware_Returns429(t *testing.T) {
rl := NewRateLimiter(1, time.Minute)
handler := RateLimitMiddleware(rl, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// First request — allowed.
req1 := httptest.NewRequest("POST", "/api/auth/login", nil)
req1.RemoteAddr = "192.168.1.1:12345"
w1 := httptest.NewRecorder()
handler(w1, req1)
if w1.Code != http.StatusOK {
t.Fatalf("first request: got %d, want 200", w1.Code)
}
// Second request — blocked.
req2 := httptest.NewRequest("POST", "/api/auth/login", nil)
req2.RemoteAddr = "192.168.1.1:12346"
w2 := httptest.NewRecorder()
handler(w2, req2)
if w2.Code != http.StatusTooManyRequests {
t.Fatalf("second request: got %d, want 429", w2.Code)
}
if w2.Header().Get("Retry-After") != "60" {
t.Fatal("missing Retry-After header")
}
}

View file

@ -0,0 +1,342 @@
package auth
import (
"encoding/json"
"log/slog"
"net/http"
"time"
)
// EmailSendFunc is a callback for sending verification emails.
// Signature: func(toEmail, userName, code string) error
type EmailSendFunc func(toEmail, userName, code string) error
// HandleRegister processes new tenant + owner registration.
// POST /api/auth/register { email, password, name, org_name, org_slug }
// Returns verification_required — user must verify email before login.
// If emailFn is nil, verification code is returned in response (dev mode).
func HandleRegister(userStore *UserStore, tenantStore *TenantStore, jwtSecret []byte, emailFn EmailSendFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
Password string `json:"password"`
Name string `json:"name"`
OrgName string `json:"org_name"`
OrgSlug string `json:"org_slug"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
if req.Email == "" || req.Password == "" || req.OrgName == "" || req.OrgSlug == "" {
http.Error(w, `{"error":"email, password, org_name, org_slug are required"}`, http.StatusBadRequest)
return
}
if len(req.Password) < 8 {
http.Error(w, `{"error":"password must be at least 8 characters"}`, http.StatusBadRequest)
return
}
if req.Name == "" {
req.Name = req.Email
}
// Create user first (admin of new tenant)
user, err := userStore.CreateUser(req.Email, req.Name, req.Password, "admin")
if err != nil {
if err == ErrUserExists {
http.Error(w, `{"error":"email already registered"}`, http.StatusConflict)
return
}
http.Error(w, `{"error":"failed to create user"}`, http.StatusInternalServerError)
return
}
// Create tenant
tenant, err := tenantStore.CreateTenant(req.OrgName, req.OrgSlug, user.ID, "starter")
if err != nil {
if err == ErrTenantExists {
http.Error(w, `{"error":"organization slug already taken"}`, http.StatusConflict)
return
}
http.Error(w, `{"error":"failed to create organization"}`, http.StatusInternalServerError)
return
}
// Update user with tenant_id
if userStore.db != nil {
userStore.db.Exec(`UPDATE users SET tenant_id = ? WHERE id = ?`, tenant.ID, user.ID)
}
// Generate verification code
code, err := userStore.SetVerifyToken(req.Email)
if err != nil {
http.Error(w, `{"error":"failed to generate verification code"}`, http.StatusInternalServerError)
return
}
// Send verification email if email service is configured
resp := map[string]interface{}{
"status": "verification_required",
"email": req.Email,
"message": "Verification code sent to your email",
"tenant": tenant,
}
if emailFn != nil {
if err := emailFn(req.Email, req.Name, code); err != nil {
slog.Error("failed to send verification email", "email", req.Email, "error", err)
// Still return success — code is in DB, user can retry
}
} else {
// Dev mode — include code in response
resp["verification_code_dev"] = code
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(resp)
}
}
// HandleVerifyEmail validates the verification code and issues JWT.
// POST /api/auth/verify { email, code }
func HandleVerifyEmail(userStore *UserStore, tenantStore *TenantStore, jwtSecret []byte) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
Code string `json:"code"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
if req.Email == "" || req.Code == "" {
http.Error(w, `{"error":"email and code required"}`, http.StatusBadRequest)
return
}
if err := userStore.VerifyEmail(req.Email, req.Code); err != nil {
if err == ErrInvalidVerifyCode {
http.Error(w, `{"error":"invalid or expired verification code"}`, http.StatusBadRequest)
return
}
http.Error(w, `{"error":"verification failed"}`, http.StatusInternalServerError)
return
}
// Get user and tenant
user, err := userStore.GetByEmail(req.Email)
if err != nil {
http.Error(w, `{"error":"user not found"}`, http.StatusNotFound)
return
}
// Find tenant for this user
var tenantID string
if userStore.db != nil {
userStore.db.QueryRow(`SELECT tenant_id FROM users WHERE id = ?`, user.ID).Scan(&tenantID)
}
// Issue JWT with tenant context
accessToken, err := Sign(Claims{
Sub: user.Email,
Role: user.Role,
TenantID: tenantID,
Exp: time.Now().Add(15 * time.Minute).Unix(),
}, jwtSecret)
if err != nil {
http.Error(w, `{"error":"failed to issue token"}`, http.StatusInternalServerError)
return
}
refreshToken, _ := Sign(Claims{
Sub: user.Email,
Role: user.Role,
TenantID: tenantID,
Exp: time.Now().Add(7 * 24 * time.Hour).Unix(),
}, jwtSecret)
var tenant *Tenant
if tenantID != "" {
tenant, _ = tenantStore.GetTenant(tenantID)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": accessToken,
"refresh_token": refreshToken,
"expires_in": 900,
"token_type": "Bearer",
"user": user,
"tenant": tenant,
})
}
}
// HandleGetTenant returns the current tenant info.
// GET /api/auth/tenant
func HandleGetTenant(tenantStore *TenantStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil || claims.TenantID == "" {
http.Error(w, `{"error":"no tenant context"}`, http.StatusForbidden)
return
}
tenant, err := tenantStore.GetTenant(claims.TenantID)
if err != nil {
http.Error(w, `{"error":"tenant not found"}`, http.StatusNotFound)
return
}
plan := tenant.GetPlan()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"tenant": tenant,
"plan": plan,
"usage": map[string]interface{}{
"events_this_month": tenant.EventsThisMonth,
"events_limit": plan.MaxEventsMonth,
"usage_percent": usagePercent(tenant.EventsThisMonth, plan.MaxEventsMonth),
},
})
}
}
// HandleUpdateTenantPlan upgrades/downgrades the tenant plan.
// POST /api/auth/tenant/plan { plan_id }
func HandleUpdateTenantPlan(tenantStore *TenantStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil || claims.Role != "admin" {
http.Error(w, `{"error":"admin role required"}`, http.StatusForbidden)
return
}
var req struct {
PlanID string `json:"plan_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
return
}
if err := tenantStore.UpdatePlan(claims.TenantID, req.PlanID); err != nil {
http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusBadRequest)
return
}
tenant, _ := tenantStore.GetTenant(claims.TenantID)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"tenant": tenant,
"plan": tenant.GetPlan(),
})
}
}
// HandleListPlans returns all available pricing plans.
// GET /api/auth/plans
func HandleListPlans() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
plans := make([]Plan, 0, len(DefaultPlans))
order := []string{"starter", "professional", "enterprise"}
for _, id := range order {
if p, ok := DefaultPlans[id]; ok {
plans = append(plans, p)
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{"plans": plans})
}
}
// HandleBillingStatus returns the billing status for the tenant.
// GET /api/auth/billing
func HandleBillingStatus(tenantStore *TenantStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil || claims.TenantID == "" {
http.Error(w, `{"error":"no tenant context"}`, http.StatusForbidden)
return
}
tenant, err := tenantStore.GetTenant(claims.TenantID)
if err != nil {
http.Error(w, `{"error":"tenant not found"}`, http.StatusNotFound)
return
}
plan := tenant.GetPlan()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"plan": plan,
"payment_customer_id": tenant.PaymentCustomerID,
"payment_sub_id": tenant.PaymentSubID,
"events_used": tenant.EventsThisMonth,
"events_limit": plan.MaxEventsMonth,
"usage_percent": usagePercent(tenant.EventsThisMonth, plan.MaxEventsMonth),
"next_reset": tenant.MonthResetAt,
})
}
}
// HandleStripeWebhook processes Stripe webhook events.
// POST /api/billing/webhook
func HandleStripeWebhook(tenantStore *TenantStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var evt struct {
Type string `json:"type"`
Data struct {
Object struct {
CustomerID string `json:"customer"`
SubscriptionID string `json:"id"`
Status string `json:"status"`
Metadata struct {
TenantID string `json:"tenant_id"`
PlanID string `json:"plan_id"`
} `json:"metadata"`
} `json:"object"`
} `json:"data"`
}
if err := json.NewDecoder(r.Body).Decode(&evt); err != nil {
http.Error(w, "invalid payload", http.StatusBadRequest)
return
}
tenantID := evt.Data.Object.Metadata.TenantID
switch evt.Type {
case "customer.subscription.created", "customer.subscription.updated":
if tenantID != "" {
tenantStore.SetStripeIDs(tenantID,
evt.Data.Object.CustomerID,
evt.Data.Object.SubscriptionID)
if planID := evt.Data.Object.Metadata.PlanID; planID != "" {
tenantStore.UpdatePlan(tenantID, planID)
}
}
case "customer.subscription.deleted":
if tenantID != "" {
tenantStore.UpdatePlan(tenantID, "starter")
tenantStore.SetStripeIDs(tenantID, evt.Data.Object.CustomerID, "")
}
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"received":true}`))
}
}
func usagePercent(used, limit int) float64 {
if limit <= 0 {
return 0
}
pct := float64(used) / float64(limit) * 100
if pct > 100 {
return 100
}
return pct
}

View file

@ -0,0 +1,322 @@
package auth
import (
"database/sql"
"errors"
"fmt"
"log/slog"
"sync"
"time"
)
// Standard tenant errors.
var (
ErrTenantNotFound = errors.New("auth: tenant not found")
ErrTenantExists = errors.New("auth: tenant already exists")
ErrQuotaExceeded = errors.New("auth: plan quota exceeded")
)
// Plan represents a subscription tier with resource limits.
type Plan struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
MaxUsers int `json:"max_users"`
MaxEventsMonth int `json:"max_events_month"`
MaxIncidents int `json:"max_incidents"`
MaxSensors int `json:"max_sensors"`
RetentionDays int `json:"retention_days"`
SLAEnabled bool `json:"sla_enabled"`
SOAREnabled bool `json:"soar_enabled"`
ComplianceEnabled bool `json:"compliance_enabled"`
OnPremise bool `json:"on_premise"` // Enterprise: on-premise deployment
PriceMonthCents int `json:"price_month_cents"` // 0 = free, -1 = custom pricing
}
// DefaultPlans defines the standard pricing tiers (prices in RUB kopecks).
var DefaultPlans = map[string]Plan{
"starter": {
ID: "starter", Name: "Starter",
Description: "AI-мониторинг: до 5 сенсоров, базовая корреляция и алерты",
MaxUsers: 10, MaxEventsMonth: 100000, MaxIncidents: 200, MaxSensors: 5,
RetentionDays: 30, SLAEnabled: true, SOAREnabled: false, ComplianceEnabled: false,
PriceMonthCents: 8990000, // 89 900 ₽/мес
},
"professional": {
ID: "professional", Name: "Professional",
Description: "Полный AI SOC: SOAR, compliance, расширенная аналитика",
MaxUsers: 50, MaxEventsMonth: 500000, MaxIncidents: 1000, MaxSensors: 25,
RetentionDays: 90, SLAEnabled: true, SOAREnabled: true, ComplianceEnabled: true,
PriceMonthCents: 14990000, // 149 900 ₽/мес
},
"enterprise": {
ID: "enterprise", Name: "Enterprise",
Description: "On-premise / выделенный инстанс. Сертификация — на стороне заказчика",
MaxUsers: -1, MaxEventsMonth: -1, MaxIncidents: -1, MaxSensors: -1,
RetentionDays: 365, SLAEnabled: true, SOAREnabled: true, ComplianceEnabled: true,
OnPremise: true,
PriceMonthCents: -1, // по запросу
},
}
// Tenant represents an isolated organization in the multi-tenant system.
type Tenant struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
PlanID string `json:"plan_id"`
PaymentCustomerID string `json:"payment_customer_id,omitempty"`
PaymentSubID string `json:"payment_sub_id,omitempty"`
OwnerUserID string `json:"owner_user_id"`
Active bool `json:"active"`
CreatedAt time.Time `json:"created_at"`
EventsThisMonth int `json:"events_this_month"`
MonthResetAt time.Time `json:"month_reset_at"`
}
// GetPlan returns the tenant's plan configuration.
func (t *Tenant) GetPlan() Plan {
if p, ok := DefaultPlans[t.PlanID]; ok {
return p
}
return DefaultPlans["starter"]
}
// CanIngestEvent checks if the tenant can still ingest events this month.
func (t *Tenant) CanIngestEvent() bool {
plan := t.GetPlan()
if plan.MaxEventsMonth < 0 {
return true // unlimited
}
return t.EventsThisMonth < plan.MaxEventsMonth
}
// TenantStore manages tenant records backed by SQLite.
type TenantStore struct {
mu sync.RWMutex
db *sql.DB
tenants map[string]*Tenant // id -> Tenant
}
// NewTenantStore creates a tenant store.
func NewTenantStore(db *sql.DB) *TenantStore {
s := &TenantStore{
db: db,
tenants: make(map[string]*Tenant),
}
if db != nil {
if err := s.migrate(); err != nil {
slog.Error("tenant store: migration failed", "error", err)
} else {
s.loadFromDB()
}
}
return s
}
func (s *TenantStore) migrate() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS tenants (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
slug TEXT UNIQUE NOT NULL,
plan_id TEXT NOT NULL DEFAULT 'free',
stripe_customer_id TEXT DEFAULT '',
stripe_sub_id TEXT DEFAULT '',
owner_user_id TEXT NOT NULL,
active INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL,
events_this_month INTEGER NOT NULL DEFAULT 0,
month_reset_at TEXT NOT NULL
);
-- Add tenant_id to users table if not exists
-- SQLite doesn't support ADD COLUMN IF NOT EXISTS, so we use a trick
`)
if err != nil {
return err
}
// Add tenant_id column to users if missing
_, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN tenant_id TEXT DEFAULT ''`)
return nil
}
func (s *TenantStore) loadFromDB() {
rows, err := s.db.Query(`SELECT id, name, slug, plan_id, stripe_customer_id, stripe_sub_id,
owner_user_id, active, created_at, events_this_month, month_reset_at FROM tenants`)
if err != nil {
slog.Error("load tenants from DB", "error", err)
return
}
defer rows.Close()
s.mu.Lock()
defer s.mu.Unlock()
for rows.Next() {
var t Tenant
var createdAt, monthReset string
if err := rows.Scan(&t.ID, &t.Name, &t.Slug, &t.PlanID, &t.PaymentCustomerID,
&t.PaymentSubID, &t.OwnerUserID, &t.Active, &createdAt, &t.EventsThisMonth, &monthReset); err != nil {
continue
}
t.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
t.MonthResetAt, _ = time.Parse(time.RFC3339, monthReset)
s.tenants[t.ID] = &t
}
slog.Info("tenants loaded from DB", "count", len(s.tenants))
}
func (s *TenantStore) persistTenant(t *Tenant) {
if s.db == nil {
return
}
_, err := s.db.Exec(`
INSERT OR REPLACE INTO tenants (id, name, slug, plan_id, stripe_customer_id, stripe_sub_id,
owner_user_id, active, created_at, events_this_month, month_reset_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
t.ID, t.Name, t.Slug, t.PlanID, t.PaymentCustomerID, t.PaymentSubID,
t.OwnerUserID, t.Active, t.CreatedAt.Format(time.RFC3339),
t.EventsThisMonth, t.MonthResetAt.Format(time.RFC3339),
)
if err != nil {
slog.Error("persist tenant", "id", t.ID, "error", err)
}
}
// CreateTenant creates a new tenant and assigns an owner.
func (s *TenantStore) CreateTenant(name, slug, ownerUserID, planID string) (*Tenant, error) {
s.mu.Lock()
defer s.mu.Unlock()
for _, t := range s.tenants {
if t.Slug == slug {
return nil, ErrTenantExists
}
}
if _, ok := DefaultPlans[planID]; !ok {
planID = "starter"
}
t := &Tenant{
ID: generateID("tnt"),
Name: name,
Slug: slug,
PlanID: planID,
OwnerUserID: ownerUserID,
Active: true,
CreatedAt: time.Now(),
EventsThisMonth: 0,
MonthResetAt: monthStart(time.Now().AddDate(0, 1, 0)),
}
s.tenants[t.ID] = t
go s.persistTenant(t)
return t, nil
}
// GetTenant returns a tenant by ID.
func (s *TenantStore) GetTenant(id string) (*Tenant, error) {
s.mu.RLock()
defer s.mu.RUnlock()
t, ok := s.tenants[id]
if !ok {
return nil, ErrTenantNotFound
}
return t, nil
}
// GetTenantBySlug returns a tenant by slug.
func (s *TenantStore) GetTenantBySlug(slug string) (*Tenant, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.tenants {
if t.Slug == slug {
return t, nil
}
}
return nil, ErrTenantNotFound
}
// ListTenants returns all tenants.
func (s *TenantStore) ListTenants() []*Tenant {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]*Tenant, 0, len(s.tenants))
for _, t := range s.tenants {
result = append(result, t)
}
return result
}
// UpdatePlan changes a tenant's plan.
func (s *TenantStore) UpdatePlan(tenantID, planID string) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.tenants[tenantID]
if !ok {
return ErrTenantNotFound
}
if _, valid := DefaultPlans[planID]; !valid {
return fmt.Errorf("auth: unknown plan %q", planID)
}
t.PlanID = planID
go s.persistTenant(t)
return nil
}
// SetStripeIDs saves Stripe customer + subscription IDs.
func (s *TenantStore) SetStripeIDs(tenantID, customerID, subID string) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.tenants[tenantID]
if !ok {
return ErrTenantNotFound
}
t.PaymentCustomerID = customerID
t.PaymentSubID = subID
go s.persistTenant(t)
return nil
}
// IncrementEvents increments the monthly event counter. Returns error if quota exceeded.
func (s *TenantStore) IncrementEvents(tenantID string, count int) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.tenants[tenantID]
if !ok {
return ErrTenantNotFound
}
// Auto-reset if past the reset date
if time.Now().After(t.MonthResetAt) {
t.EventsThisMonth = 0
t.MonthResetAt = monthStart(time.Now().AddDate(0, 1, 0))
}
plan := t.GetPlan()
if plan.MaxEventsMonth >= 0 && t.EventsThisMonth+count > plan.MaxEventsMonth {
return ErrQuotaExceeded
}
t.EventsThisMonth += count
go s.persistTenant(t)
return nil
}
// DeactivateTenant marks a tenant as inactive (subscription cancelled).
func (s *TenantStore) DeactivateTenant(tenantID string) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.tenants[tenantID]
if !ok {
return ErrTenantNotFound
}
t.Active = false
go s.persistTenant(t)
return nil
}
func monthStart(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
}

View file

@ -0,0 +1,485 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
)
// Standard user errors.
var (
ErrUserNotFound = errors.New("auth: user not found")
ErrUserExists = errors.New("auth: user already exists")
ErrInvalidPassword = errors.New("auth: invalid password")
ErrUserDisabled = errors.New("auth: account disabled")
ErrEmailNotVerified = errors.New("auth: email not verified")
ErrInvalidVerifyCode = errors.New("auth: invalid or expired verification code")
)
// User represents an authenticated user in the system.
type User struct {
ID string `json:"id"`
Email string `json:"email"`
DisplayName string `json:"display_name"`
Role string `json:"role"` // admin, analyst, viewer
Active bool `json:"active"`
EmailVerified bool `json:"email_verified"`
PasswordHash string `json:"-"` // never serialized
VerifyToken string `json:"-"` // never serialized
VerifyExpiry *time.Time `json:"-"` // never serialized
CreatedAt time.Time `json:"created_at"`
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
}
// UserStore manages user credentials backed by SQLite.
// Falls back to in-memory store if no DB is provided.
type UserStore struct {
mu sync.RWMutex
db *sql.DB
users map[string]*User // email -> User (in-memory cache / fallback)
}
// NewUserStore creates a user store. If db is nil, uses in-memory only.
func NewUserStore(db ...*sql.DB) *UserStore {
s := &UserStore{
users: make(map[string]*User),
}
if len(db) > 0 && db[0] != nil {
s.db = db[0]
if err := s.migrate(); err != nil {
slog.Error("user store: migration failed", "error", err)
} else {
s.loadFromDB()
}
}
// Ensure default admin exists
if _, err := s.GetByEmail("admin@xn--80akacl3adqr.xn--p1acf"); err != nil {
hash, _ := bcrypt.GenerateFromPassword([]byte("syntrex-admin-2026"), bcrypt.DefaultCost)
admin := &User{
ID: generateID("usr"),
Email: "admin@xn--80akacl3adqr.xn--p1acf",
DisplayName: "Administrator",
Role: "admin",
Active: true,
EmailVerified: true, // default admin is pre-verified
PasswordHash: string(hash),
CreatedAt: time.Now(),
}
s.mu.Lock()
s.users[admin.Email] = admin
s.mu.Unlock()
if s.db != nil {
s.persistUser(admin)
}
slog.Info("default admin created", "email", admin.Email)
}
return s
}
// migrate creates the users table if not exists.
func (s *UserStore) migrate() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
display_name TEXT NOT NULL DEFAULT '',
role TEXT NOT NULL DEFAULT 'viewer',
active INTEGER NOT NULL DEFAULT 1,
email_verified INTEGER NOT NULL DEFAULT 0,
password_hash TEXT NOT NULL,
verify_token TEXT DEFAULT '',
verify_expiry TEXT DEFAULT '',
created_at TEXT NOT NULL,
last_login_at TEXT
);
CREATE TABLE IF NOT EXISTS api_keys (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id),
key_hash TEXT NOT NULL,
name TEXT NOT NULL DEFAULT '',
role TEXT NOT NULL DEFAULT 'viewer',
created_at TEXT NOT NULL,
last_used TEXT
);
`)
if err != nil {
return err
}
// Add columns if upgrading from older schema
s.db.Exec(`ALTER TABLE users ADD COLUMN email_verified INTEGER NOT NULL DEFAULT 0`)
s.db.Exec(`ALTER TABLE users ADD COLUMN verify_token TEXT DEFAULT ''`)
s.db.Exec(`ALTER TABLE users ADD COLUMN verify_expiry TEXT DEFAULT ''`)
return nil
}
// loadFromDB loads all users from SQLite into memory cache.
func (s *UserStore) loadFromDB() {
rows, err := s.db.Query(`SELECT id, email, display_name, role, active, password_hash, created_at, last_login_at FROM users`)
if err != nil {
slog.Error("load users from DB", "error", err)
return
}
defer rows.Close()
s.mu.Lock()
defer s.mu.Unlock()
for rows.Next() {
var u User
var createdAt string
var lastLogin sql.NullString
if err := rows.Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role, &u.Active, &u.PasswordHash, &createdAt, &lastLogin); err != nil {
continue
}
u.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
if lastLogin.Valid {
t, _ := time.Parse(time.RFC3339, lastLogin.String)
u.LastLoginAt = &t
}
s.users[u.Email] = &u
}
slog.Info("users loaded from DB", "count", len(s.users))
}
// persistUser writes a user to SQLite.
func (s *UserStore) persistUser(u *User) {
if s.db == nil {
return
}
var lastLogin *string
if u.LastLoginAt != nil {
t := u.LastLoginAt.Format(time.RFC3339)
lastLogin = &t
}
var verifyExpiry string
if u.VerifyExpiry != nil {
verifyExpiry = u.VerifyExpiry.Format(time.RFC3339)
}
verified := 0
if u.EmailVerified {
verified = 1
}
_, err := s.db.Exec(`
INSERT OR REPLACE INTO users (id, email, display_name, role, active, email_verified, password_hash, verify_token, verify_expiry, created_at, last_login_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
u.ID, u.Email, u.DisplayName, u.Role, u.Active, verified, u.PasswordHash, u.VerifyToken, verifyExpiry, u.CreatedAt.Format(time.RFC3339), lastLogin,
)
if err != nil {
slog.Error("persist user", "email", u.Email, "error", err)
}
}
// --- CRUD Operations ---
// CreateUser creates a new user with a hashed password.
func (s *UserStore) CreateUser(email, displayName, password, role string) (*User, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.users[email]; exists {
return nil, ErrUserExists
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("auth: hash password: %w", err)
}
u := &User{
ID: generateID("usr"),
Email: email,
DisplayName: displayName,
Role: role,
Active: true,
PasswordHash: string(hash),
CreatedAt: time.Now(),
}
s.users[email] = u
go s.persistUser(u)
return u, nil
}
// Authenticate validates email/password and returns the user.
func (s *UserStore) Authenticate(email, password string) (*User, error) {
s.mu.RLock()
user, ok := s.users[email]
s.mu.RUnlock()
if !ok {
return nil, ErrUserNotFound
}
if !user.Active {
return nil, ErrUserDisabled
}
if !user.EmailVerified {
return nil, ErrEmailNotVerified
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return nil, ErrInvalidPassword
}
// Update last login
now := time.Now()
s.mu.Lock()
user.LastLoginAt = &now
s.mu.Unlock()
go s.persistUser(user)
return user, nil
}
// SetVerifyToken generates a 6-digit verification code for a user.
func (s *UserStore) SetVerifyToken(email string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.users[email]
if !ok {
return "", ErrUserNotFound
}
// Generate 6-digit code
b := make([]byte, 3)
rand.Read(b)
code := fmt.Sprintf("%06d", int(b[0])<<16|int(b[1])<<8|int(b[2])%1000000)
if len(code) > 6 {
code = code[:6]
}
expiry := time.Now().Add(24 * time.Hour)
user.VerifyToken = code
user.VerifyExpiry = &expiry
go s.persistUser(user)
return code, nil
}
// VerifyEmail checks the verification code and marks email as verified.
func (s *UserStore) VerifyEmail(email, code string) error {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.users[email]
if !ok {
return ErrUserNotFound
}
if user.VerifyToken == "" || user.VerifyToken != code {
return ErrInvalidVerifyCode
}
if user.VerifyExpiry != nil && time.Now().After(*user.VerifyExpiry) {
return ErrInvalidVerifyCode
}
user.EmailVerified = true
user.VerifyToken = ""
user.VerifyExpiry = nil
go s.persistUser(user)
return nil
}
// GetByEmail returns a user by email.
func (s *UserStore) GetByEmail(email string) (*User, error) {
s.mu.RLock()
defer s.mu.RUnlock()
user, ok := s.users[email]
if !ok {
return nil, ErrUserNotFound
}
return user, nil
}
// GetByID returns a user by ID.
func (s *UserStore) GetByID(id string) (*User, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, u := range s.users {
if u.ID == id {
return u, nil
}
}
return nil, ErrUserNotFound
}
// ListUsers returns all users.
func (s *UserStore) ListUsers() []*User {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]*User, 0, len(s.users))
for _, u := range s.users {
result = append(result, u)
}
return result
}
// UpdateUser updates a user's display name, role, and active status.
func (s *UserStore) UpdateUser(id, displayName, role string, active bool) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, u := range s.users {
if u.ID == id {
if displayName != "" {
u.DisplayName = displayName
}
if role != "" {
u.Role = role
}
u.Active = active
go s.persistUser(u)
return nil
}
}
return ErrUserNotFound
}
// ChangePassword updates a user's password.
func (s *UserStore) ChangePassword(id, newPassword string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("auth: hash password: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
for _, u := range s.users {
if u.ID == id {
u.PasswordHash = string(hash)
go s.persistUser(u)
return nil
}
}
return ErrUserNotFound
}
// DeleteUser permanently removes a user.
func (s *UserStore) DeleteUser(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
for email, u := range s.users {
if u.ID == id {
delete(s.users, email)
if s.db != nil {
go s.db.Exec(`DELETE FROM users WHERE id = ?`, id)
}
return nil
}
}
return ErrUserNotFound
}
// --- API Key Management ---
// APIKey represents an API key for programmatic access.
type APIKey struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Role string `json:"role"`
KeyPrefix string `json:"key_prefix"` // first 8 chars for display
CreatedAt time.Time `json:"created_at"`
LastUsed *time.Time `json:"last_used,omitempty"`
}
// CreateAPIKey generates a new API key for a user. Returns the full key (only shown once).
func (s *UserStore) CreateAPIKey(userID, name, role string) (string, *APIKey, error) {
rawKey := make([]byte, 32)
if _, err := rand.Read(rawKey); err != nil {
return "", nil, err
}
fullKey := "stx_" + hex.EncodeToString(rawKey)
keyHash := hashKey(fullKey)
ak := &APIKey{
ID: generateID("key"),
UserID: userID,
Name: name,
Role: role,
KeyPrefix: fullKey[:12],
CreatedAt: time.Now(),
}
if s.db != nil {
_, err := s.db.Exec(`INSERT INTO api_keys (id, user_id, key_hash, name, role, created_at) VALUES (?,?,?,?,?,?)`,
ak.ID, ak.UserID, keyHash, ak.Name, ak.Role, ak.CreatedAt.Format(time.RFC3339))
if err != nil {
return "", nil, err
}
}
return fullKey, ak, nil
}
// ValidateAPIKey checks an API key and returns the associated role.
func (s *UserStore) ValidateAPIKey(key string) (string, string, error) {
if s.db == nil {
return "", "", fmt.Errorf("no database for API keys")
}
keyHash := hashKey(key)
var userID, role string
err := s.db.QueryRow(`SELECT user_id, role FROM api_keys WHERE key_hash = ?`, keyHash).Scan(&userID, &role)
if err != nil {
return "", "", ErrInvalidToken
}
// Update last_used
go s.db.Exec(`UPDATE api_keys SET last_used = ? WHERE key_hash = ?`, time.Now().Format(time.RFC3339), keyHash)
return userID, role, nil
}
// ListAPIKeys returns all API keys for a user.
func (s *UserStore) ListAPIKeys(userID string) ([]APIKey, error) {
if s.db == nil {
return nil, nil
}
rows, err := s.db.Query(`SELECT id, user_id, name, role, created_at, last_used FROM api_keys WHERE user_id = ?`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var keys []APIKey
for rows.Next() {
var ak APIKey
var createdAt string
var lastUsed sql.NullString
if err := rows.Scan(&ak.ID, &ak.UserID, &ak.Name, &ak.Role, &createdAt, &lastUsed); err != nil {
continue
}
ak.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
if lastUsed.Valid {
t, _ := time.Parse(time.RFC3339, lastUsed.String)
ak.LastUsed = &t
}
keys = append(keys, ak)
}
return keys, nil
}
// DeleteAPIKey revokes an API key.
func (s *UserStore) DeleteAPIKey(keyID, userID string) error {
if s.db == nil {
return nil
}
_, err := s.db.Exec(`DELETE FROM api_keys WHERE id = ? AND user_id = ?`, keyID, userID)
return err
}
// --- Helpers ---
func generateID(prefix string) string {
b := make([]byte, 8)
rand.Read(b)
return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(b))
}
func hashKey(key string) string {
h := sha256.Sum256([]byte(key))
return hex.EncodeToString(h[:])
}