package auth import ( "encoding/json" "log/slog" "net/http" "os" "regexp" "time" ) // htmlTagRegex strips HTML/script tags from user input (M5 XSS prevention). var htmlTagRegex = regexp.MustCompile(`<[^>]*>`) // EmailSendFunc is a callback for sending verification emails. // Signature: func(toEmail, userName, code string) error type EmailSendFunc func(toEmail, userName, code string) error // HandleRegister processes new tenant + owner registration. // POST /api/auth/register { email, password, name, org_name, org_slug } // Returns verification_required — user must verify email before login. // If emailFn is nil, verification code is returned in response (dev mode). func HandleRegister(userStore *UserStore, tenantStore *TenantStore, jwtSecret []byte, emailFn EmailSendFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // SEC-M4: Server-side registration gate if os.Getenv("SOC_REGISTRATION_OPEN") == "false" { http.Error(w, `{"error":"registration is closed — contact admin for an invitation"}`, http.StatusForbidden) return } var req struct { Email string `json:"email"` Password string `json:"password"` Name string `json:"name"` OrgName string `json:"org_name"` OrgSlug string `json:"org_slug"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest) return } if req.Email == "" || req.Password == "" || req.OrgName == "" || req.OrgSlug == "" { http.Error(w, `{"error":"email, password, org_name, org_slug are required"}`, http.StatusBadRequest) return } if len(req.Password) < 8 { http.Error(w, `{"error":"password must be at least 8 characters"}`, http.StatusBadRequest) return } if req.Name == "" { req.Name = req.Email } // SEC-M5: Strip HTML tags from user input to prevent stored XSS req.Name = htmlTagRegex.ReplaceAllString(req.Name, "") req.OrgName = htmlTagRegex.ReplaceAllString(req.OrgName, "") // Create user first (admin of new tenant) user, err := userStore.CreateUser(req.Email, req.Name, req.Password, "admin") if err != nil { if err == ErrUserExists { http.Error(w, `{"error":"email already registered"}`, http.StatusConflict) return } http.Error(w, `{"error":"failed to create user"}`, http.StatusInternalServerError) return } // Create tenant tenant, err := tenantStore.CreateTenant(req.OrgName, req.OrgSlug, user.ID, "free") if err != nil { if err == ErrTenantExists { http.Error(w, `{"error":"organization slug already taken"}`, http.StatusConflict) return } http.Error(w, `{"error":"failed to create organization"}`, http.StatusInternalServerError) return } // Update user with tenant_id if userStore.db != nil { userStore.db.Exec(`UPDATE users SET tenant_id = ? WHERE id = ?`, tenant.ID, user.ID) } // Generate verification code code, err := userStore.SetVerifyToken(req.Email) if err != nil { http.Error(w, `{"error":"failed to generate verification code"}`, http.StatusInternalServerError) return } // Send verification email if email service is configured resp := map[string]interface{}{ "status": "verification_required", "email": req.Email, "message": "Verification code sent to your email", "tenant": tenant, } if emailFn != nil { if err := emailFn(req.Email, req.Name, code); err != nil { slog.Error("failed to send verification email", "email", req.Email, "error", err) // Still return success — code is in DB, user can retry } } else { // SEC: Never expose verification code in API response. // Log server-side only for development debugging. slog.Warn("email service not configured — verification code logged (dev only)", "email", req.Email, "code", code) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(resp) } } // HandleVerifyEmail validates the verification code and issues JWT. // POST /api/auth/verify { email, code } func HandleVerifyEmail(userStore *UserStore, tenantStore *TenantStore, jwtSecret []byte) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req struct { Email string `json:"email"` Code string `json:"code"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest) return } if req.Email == "" || req.Code == "" { http.Error(w, `{"error":"email and code required"}`, http.StatusBadRequest) return } if err := userStore.VerifyEmail(req.Email, req.Code); err != nil { if err == ErrInvalidVerifyCode { http.Error(w, `{"error":"invalid or expired verification code"}`, http.StatusBadRequest) return } http.Error(w, `{"error":"verification failed"}`, http.StatusInternalServerError) return } // Get user and tenant user, err := userStore.GetByEmail(req.Email) if err != nil { http.Error(w, `{"error":"user not found"}`, http.StatusNotFound) return } // Find tenant for this user var tenantID string if userStore.db != nil { userStore.db.QueryRow(`SELECT tenant_id FROM users WHERE id = ?`, user.ID).Scan(&tenantID) } // Issue JWT with tenant context accessToken, err := Sign(Claims{ Sub: user.Email, Role: user.Role, TenantID: tenantID, TokenType: "access", Exp: time.Now().Add(15 * time.Minute).Unix(), }, jwtSecret) if err != nil { http.Error(w, `{"error":"failed to issue token"}`, http.StatusInternalServerError) return } refreshToken, _ := Sign(Claims{ Sub: user.Email, Role: user.Role, TenantID: tenantID, TokenType: "refresh", Exp: time.Now().Add(7 * 24 * time.Hour).Unix(), }, jwtSecret) var tenant *Tenant if tenantID != "" { tenant, _ = tenantStore.GetTenant(tenantID) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "access_token": accessToken, "refresh_token": refreshToken, "expires_in": 900, "token_type": "Bearer", "user": user, "tenant": tenant, }) } } // HandleGetTenant returns the current tenant info. // GET /api/auth/tenant func HandleGetTenant(tenantStore *TenantStore) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { claims := GetClaims(r.Context()) if claims == nil || claims.TenantID == "" { http.Error(w, `{"error":"no tenant context"}`, http.StatusForbidden) return } tenant, err := tenantStore.GetTenant(claims.TenantID) if err != nil { http.Error(w, `{"error":"tenant not found"}`, http.StatusNotFound) return } plan := tenant.GetPlan() w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "tenant": tenant, "plan": plan, "usage": map[string]interface{}{ "events_this_month": tenant.EventsThisMonth, "events_limit": plan.MaxEventsMonth, "usage_percent": usagePercent(tenant.EventsThisMonth, plan.MaxEventsMonth), }, }) } } // HandleUpdateTenantPlan upgrades/downgrades the tenant plan. // POST /api/auth/tenant/plan { plan_id } func HandleUpdateTenantPlan(tenantStore *TenantStore) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { claims := GetClaims(r.Context()) if claims == nil || claims.Role != "admin" { http.Error(w, `{"error":"admin role required"}`, http.StatusForbidden) return } var req struct { PlanID string `json:"plan_id"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest) return } if err := tenantStore.UpdatePlan(claims.TenantID, req.PlanID); err != nil { http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusBadRequest) return } tenant, _ := tenantStore.GetTenant(claims.TenantID) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "tenant": tenant, "plan": tenant.GetPlan(), }) } } // HandleListPlans returns all available pricing plans. // GET /api/auth/plans func HandleListPlans() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { plans := make([]Plan, 0, len(DefaultPlans)) order := []string{"free", "starter", "professional", "enterprise"} for _, id := range order { if p, ok := DefaultPlans[id]; ok { plans = append(plans, p) } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{"plans": plans}) } } // HandleBillingStatus returns the billing status for the tenant. // GET /api/auth/billing func HandleBillingStatus(tenantStore *TenantStore) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { claims := GetClaims(r.Context()) if claims == nil || claims.TenantID == "" { http.Error(w, `{"error":"no tenant context"}`, http.StatusForbidden) return } tenant, err := tenantStore.GetTenant(claims.TenantID) if err != nil { http.Error(w, `{"error":"tenant not found"}`, http.StatusNotFound) return } plan := tenant.GetPlan() w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "plan": plan, "payment_customer_id": tenant.PaymentCustomerID, "payment_sub_id": tenant.PaymentSubID, "events_used": tenant.EventsThisMonth, "events_limit": plan.MaxEventsMonth, "usage_percent": usagePercent(tenant.EventsThisMonth, plan.MaxEventsMonth), "next_reset": tenant.MonthResetAt, }) } } // HandleStripeWebhook processes Stripe webhook events. // POST /api/billing/webhook func HandleStripeWebhook(tenantStore *TenantStore) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var evt struct { Type string `json:"type"` Data struct { Object struct { CustomerID string `json:"customer"` SubscriptionID string `json:"id"` Status string `json:"status"` Metadata struct { TenantID string `json:"tenant_id"` PlanID string `json:"plan_id"` } `json:"metadata"` } `json:"object"` } `json:"data"` } if err := json.NewDecoder(r.Body).Decode(&evt); err != nil { http.Error(w, "invalid payload", http.StatusBadRequest) return } tenantID := evt.Data.Object.Metadata.TenantID switch evt.Type { case "customer.subscription.created", "customer.subscription.updated": if tenantID != "" { tenantStore.SetStripeIDs(tenantID, evt.Data.Object.CustomerID, evt.Data.Object.SubscriptionID) if planID := evt.Data.Object.Metadata.PlanID; planID != "" { tenantStore.UpdatePlan(tenantID, planID) } } case "customer.subscription.deleted": if tenantID != "" { tenantStore.UpdatePlan(tenantID, "starter") tenantStore.SetStripeIDs(tenantID, evt.Data.Object.CustomerID, "") } } w.WriteHeader(http.StatusOK) w.Write([]byte(`{"received":true}`)) } } func usagePercent(used, limit int) float64 { if limit <= 0 { return 0 } pct := float64(used) / float64(limit) * 100 if pct > 100 { return 100 } return pct }