gomcp/internal/infrastructure/auth/tenant_handlers.go

379 lines
12 KiB
Go

package auth
import (
"encoding/json"
"log/slog"
"net/http"
"os"
"regexp"
"time"
)
// htmlTagRegex matches any span of the form "<...>": an opening angle
// bracket, any run of characters other than '>', and the closing '>'.
// Handlers replace every match with the empty string to strip HTML/script
// tags from user input (M5 XSS prevention).
// An unterminated "<tag" with no closing '>' does not match and is left in
// place, so this is a best-effort filter, not a replacement for escaping
// values when they are rendered.
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
// EmailSendFunc is a callback for sending verification emails.
// It is called with the recipient address, the user's display name and the
// verification code, and returns a non-nil error when the send fails.
// Signature: func(toEmail, userName, code string) error
type EmailSendFunc func(toEmail, userName, code string) error
// HandleRegister processes new tenant + owner registration.
// POST /api/auth/register { email, password, name, org_name, org_slug }
//
// On success it responds 201 with status "verification_required"; the user
// must verify their email before login. The verification code is never part
// of the response: it is handed to emailFn, or, when emailFn is nil (dev
// mode), written to the server log only.
//
// Responses:
//   - 403 when SOC_REGISTRATION_OPEN is "false"
//   - 400 on a malformed body, a missing required field, or a short password
//   - 409 when the email or the organization slug is already taken
//   - 500 when the user, tenant, user→tenant link, or verification code
//     cannot be persisted
//
// jwtSecret is not used by this handler; it is kept in the signature for
// existing callers.
func HandleRegister(userStore *UserStore, tenantStore *TenantStore, jwtSecret []byte, emailFn EmailSendFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// SEC-M4: Server-side registration gate
		if os.Getenv("SOC_REGISTRATION_OPEN") == "false" {
			http.Error(w, `{"error":"registration is closed — contact admin for an invitation"}`, http.StatusForbidden)
			return
		}

		var req struct {
			Email    string `json:"email"`
			Password string `json:"password"`
			Name     string `json:"name"`
			OrgName  string `json:"org_name"`
			OrgSlug  string `json:"org_slug"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
			return
		}
		if req.Email == "" || req.Password == "" || req.OrgName == "" || req.OrgSlug == "" {
			http.Error(w, `{"error":"email, password, org_name, org_slug are required"}`, http.StatusBadRequest)
			return
		}
		if len(req.Password) < 8 {
			http.Error(w, `{"error":"password must be at least 8 characters"}`, http.StatusBadRequest)
			return
		}
		if req.Name == "" {
			req.Name = req.Email
		}

		// SEC-M5: Strip HTML tags from user input to prevent stored XSS
		req.Name = htmlTagRegex.ReplaceAllString(req.Name, "")
		req.OrgName = htmlTagRegex.ReplaceAllString(req.OrgName, "")
		// An org name made only of tags is empty after stripping; re-apply the
		// required-field rule so an empty name is never stored.
		if req.OrgName == "" {
			http.Error(w, `{"error":"email, password, org_name, org_slug are required"}`, http.StatusBadRequest)
			return
		}

		// Create user first (admin of new tenant)
		user, err := userStore.CreateUser(req.Email, req.Name, req.Password, "admin")
		if err != nil {
			if err == ErrUserExists {
				http.Error(w, `{"error":"email already registered"}`, http.StatusConflict)
				return
			}
			http.Error(w, `{"error":"failed to create user"}`, http.StatusInternalServerError)
			return
		}

		// Create tenant
		// NOTE(review): when this fails the user created above is left in
		// place, so a retry with the same email presumably hits the
		// "already registered" branch — consider rolling the user back.
		tenant, err := tenantStore.CreateTenant(req.OrgName, req.OrgSlug, user.ID, "free")
		if err != nil {
			if err == ErrTenantExists {
				http.Error(w, `{"error":"organization slug already taken"}`, http.StatusConflict)
				return
			}
			http.Error(w, `{"error":"failed to create organization"}`, http.StatusInternalServerError)
			return
		}

		// Update user with tenant_id. A failure here would leave an account
		// with no tenant, so it is reported instead of being ignored.
		if userStore.db != nil {
			if _, err := userStore.db.Exec(`UPDATE users SET tenant_id = ? WHERE id = ?`, tenant.ID, user.ID); err != nil {
				slog.Error("failed to link user to tenant", "user_id", user.ID, "tenant_id", tenant.ID, "error", err)
				http.Error(w, `{"error":"failed to link user to organization"}`, http.StatusInternalServerError)
				return
			}
		}

		// Generate verification code
		code, err := userStore.SetVerifyToken(req.Email)
		if err != nil {
			http.Error(w, `{"error":"failed to generate verification code"}`, http.StatusInternalServerError)
			return
		}

		resp := map[string]any{
			"status":  "verification_required",
			"email":   req.Email,
			"message": "Verification code sent to your email",
			"tenant":  tenant,
		}

		// Send verification email if email service is configured
		if emailFn != nil {
			if err := emailFn(req.Email, req.Name, code); err != nil {
				slog.Error("failed to send verification email", "email", req.Email, "error", err)
				// Still return success — code is in DB, user can retry
			}
		} else {
			// SEC: Never expose verification code in API response.
			// Log server-side only for development debugging.
			slog.Warn("email service not configured — verification code logged (dev only)",
				"email", req.Email, "code", code)
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			// Headers are already sent; all that is left is to record it.
			slog.Error("failed to write registration response", "error", err)
		}
	}
}
// HandleVerifyEmail validates the verification code and issues JWTs.
// POST /api/auth/verify { email, code }
//
// On success the access and refresh tokens are delivered as httpOnly
// cookies ("syntrex_token", "syntrex_refresh") and the JSON body carries
// csrf_token, user and tenant. tenant is null when the user has no tenant
// or the tenant cannot be loaded.
//
// Responses:
//   - 400 on a malformed body, a missing field, or ErrInvalidVerifyCode
//   - 404 when the user cannot be loaded after verification
//   - 500 when verification fails for another reason or a token cannot be
//     signed
func HandleVerifyEmail(userStore *UserStore, tenantStore *TenantStore, jwtSecret []byte) http.HandlerFunc {
	// Token lifetimes; the cookie MaxAge values are derived from them so the
	// two cannot drift apart.
	const (
		accessTTL  = 15 * time.Minute
		refreshTTL = 7 * 24 * time.Hour
	)

	return func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			Email string `json:"email"`
			Code  string `json:"code"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
			return
		}
		if req.Email == "" || req.Code == "" {
			http.Error(w, `{"error":"email and code required"}`, http.StatusBadRequest)
			return
		}

		if err := userStore.VerifyEmail(req.Email, req.Code); err != nil {
			if err == ErrInvalidVerifyCode {
				http.Error(w, `{"error":"invalid or expired verification code"}`, http.StatusBadRequest)
				return
			}
			http.Error(w, `{"error":"verification failed"}`, http.StatusInternalServerError)
			return
		}

		// Get user and tenant
		user, err := userStore.GetByEmail(req.Email)
		if err != nil {
			http.Error(w, `{"error":"user not found"}`, http.StatusNotFound)
			return
		}

		// Find tenant for this user. This stays best-effort (a user without
		// a tenant still gets tokens), but the failure is now logged.
		var tenantID string
		if userStore.db != nil {
			if err := userStore.db.QueryRow(`SELECT tenant_id FROM users WHERE id = ?`, user.ID).Scan(&tenantID); err != nil {
				slog.Warn("could not resolve tenant for user", "user_id", user.ID, "error", err)
				tenantID = ""
			}
		}

		// Issue JWTs with tenant context
		now := time.Now()
		accessToken, err := Sign(Claims{
			Sub:       user.Email,
			Role:      user.Role,
			TenantID:  tenantID,
			TokenType: "access",
			Exp:       now.Add(accessTTL).Unix(),
		}, jwtSecret)
		if err != nil {
			http.Error(w, `{"error":"failed to issue token"}`, http.StatusInternalServerError)
			return
		}
		// A failed refresh-token signature must not produce an empty cookie.
		refreshToken, err := Sign(Claims{
			Sub:       user.Email,
			Role:      user.Role,
			TenantID:  tenantID,
			TokenType: "refresh",
			Exp:       now.Add(refreshTTL).Unix(),
		}, jwtSecret)
		if err != nil {
			http.Error(w, `{"error":"failed to issue token"}`, http.StatusInternalServerError)
			return
		}

		var tenant *Tenant
		if tenantID != "" {
			tenant, err = tenantStore.GetTenant(tenantID)
			if err != nil {
				slog.Warn("could not load tenant for verified user", "tenant_id", tenantID, "error", err)
				tenant = nil
			}
		}

		// Mark the cookies Secure whenever the request is known to have
		// arrived over HTTPS, either directly or via a TLS-terminating proxy.
		// Plain-HTTP development setups keep working unchanged.
		secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"

		// SEC: H1 - Use httpOnly Cookies instead of returning JSON tokens
		http.SetCookie(w, &http.Cookie{
			Name:     "syntrex_token",
			Value:    accessToken,
			Path:     "/",
			HttpOnly: true,
			Secure:   secure,
			SameSite: http.SameSiteLaxMode,
			MaxAge:   int(accessTTL / time.Second),
		})
		http.SetCookie(w, &http.Cookie{
			Name:     "syntrex_refresh",
			Value:    refreshToken,
			Path:     "/",
			HttpOnly: true,
			Secure:   secure,
			SameSite: http.SameSiteLaxMode,
			MaxAge:   int(refreshTTL / time.Second),
		})

		// SEC: M2 - Generate stateless CSRF token from the access token.
		// NOTE(review): assumes hmacSign always returns at least 32
		// elements — confirm against its definition.
		csrfToken := hmacSign([]byte(accessToken), jwtSecret)[:32]

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]any{
			"csrf_token": csrfToken,
			"user":       user,
			"tenant":     tenant,
		}); err != nil {
			slog.Error("failed to write verification response", "error", err)
		}
	}
}
// HandleGetTenant serves the caller's tenant together with its plan and the
// current-month usage figures.
// GET /api/auth/tenant
//
// Responds 403 when the request carries no claims or an empty tenant ID,
// and 404 when the tenant lookup fails.
func HandleGetTenant(tenantStore *TenantStore) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Collapse "no claims" and "claims without tenant" into one check.
		var tenantID string
		if c := GetClaims(r.Context()); c != nil {
			tenantID = c.TenantID
		}
		if tenantID == "" {
			http.Error(w, `{"error":"no tenant context"}`, http.StatusForbidden)
			return
		}

		org, lookupErr := tenantStore.GetTenant(tenantID)
		if lookupErr != nil {
			http.Error(w, `{"error":"tenant not found"}`, http.StatusNotFound)
			return
		}

		tier := org.GetPlan()
		used, limit := org.EventsThisMonth, tier.MaxEventsMonth
		usage := map[string]any{
			"events_this_month": used,
			"events_limit":      limit,
			"usage_percent":     usagePercent(used, limit),
		}
		payload := map[string]any{
			"tenant": org,
			"plan":   tier,
			"usage":  usage,
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}
}
// HandleUpdateTenantPlan upgrades/downgrades the tenant plan.
// POST /api/auth/tenant/plan { plan_id }
//
// The caller must hold the "admin" role and have a tenant in its claims.
// On success the refreshed tenant and its plan are returned as JSON.
//
// Responses:
//   - 403 when the caller is not an admin or has no tenant context
//   - 400 on a malformed body or when UpdatePlan rejects the change; the
//     error text is passed through, JSON-escaped
//   - 500 when the tenant cannot be re-read after the update
func HandleUpdateTenantPlan(tenantStore *TenantStore) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		claims := GetClaims(r.Context())
		if claims == nil || claims.Role != "admin" {
			http.Error(w, `{"error":"admin role required"}`, http.StatusForbidden)
			return
		}
		// Same guard as the other tenant handlers: never act on an empty ID.
		if claims.TenantID == "" {
			http.Error(w, `{"error":"no tenant context"}`, http.StatusForbidden)
			return
		}

		var req struct {
			PlanID string `json:"plan_id"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
			return
		}

		if err := tenantStore.UpdatePlan(claims.TenantID, req.PlanID); err != nil {
			// Marshal rather than concatenate: an error text containing a
			// quote or backslash would otherwise produce invalid JSON.
			// Marshalling a map[string]string cannot fail.
			body, _ := json.Marshal(map[string]string{"error": err.Error()})
			http.Error(w, string(body), http.StatusBadRequest)
			return
		}

		// Re-read the tenant. Calling GetPlan on a nil tenant would panic,
		// so a failed lookup is reported instead of being ignored.
		tenant, err := tenantStore.GetTenant(claims.TenantID)
		if err != nil || tenant == nil {
			slog.Error("plan updated but tenant could not be reloaded", "tenant_id", claims.TenantID, "error", err)
			http.Error(w, `{"error":"failed to load tenant"}`, http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"tenant": tenant,
			"plan":   tenant.GetPlan(),
		})
	}
}
// HandleListPlans serves the pricing plans found in DefaultPlans, listed
// from the smallest tier to the largest.
// GET /api/auth/plans
//
// Map iteration order is random, so the response is built by walking a
// fixed tier list; tiers missing from DefaultPlans are skipped.
func HandleListPlans() http.HandlerFunc {
	tiers := [...]string{"free", "starter", "professional", "enterprise"}

	return func(w http.ResponseWriter, _ *http.Request) {
		ordered := make([]Plan, 0, len(DefaultPlans))
		for i := range tiers {
			plan, found := DefaultPlans[tiers[i]]
			if !found {
				continue
			}
			ordered = append(ordered, plan)
		}

		payload := map[string]any{"plans": ordered}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}
}
// HandleBillingStatus serves the billing snapshot of the caller's tenant:
// its plan, payment identifiers, event usage and the next usage reset.
// GET /api/auth/billing
//
// Responds 403 when the request carries no claims or an empty tenant ID,
// and 404 when the tenant lookup fails.
func HandleBillingStatus(tenantStore *TenantStore) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		who := GetClaims(r.Context())
		switch {
		case who == nil, who.TenantID == "":
			http.Error(w, `{"error":"no tenant context"}`, http.StatusForbidden)
			return
		}

		account, err := tenantStore.GetTenant(who.TenantID)
		if err != nil {
			http.Error(w, `{"error":"tenant not found"}`, http.StatusNotFound)
			return
		}

		tier := account.GetPlan()
		status := make(map[string]any, 7)
		status["plan"] = tier
		status["payment_customer_id"] = account.PaymentCustomerID
		status["payment_sub_id"] = account.PaymentSubID
		status["events_used"] = account.EventsThisMonth
		status["events_limit"] = tier.MaxEventsMonth
		status["usage_percent"] = usagePercent(account.EventsThisMonth, tier.MaxEventsMonth)
		status["next_reset"] = account.MonthResetAt

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(status)
	}
}
// HandleStripeWebhook processes Stripe webhook events.
// POST /api/billing/webhook
//
// Handled event types:
//   - customer.subscription.created / .updated: store the customer and
//     subscription IDs on the tenant named in metadata.tenant_id and, when
//     metadata.plan_id is set, switch the tenant to that plan.
//   - customer.subscription.deleted: move the tenant to the "starter" plan
//     and clear its subscription ID.
//
// Every other event type, and any event without metadata.tenant_id, is
// acknowledged without action.
//
// Responses:
//   - 400 when the body is not valid JSON or exceeds 64 KiB
//   - 500 when a plan change fails, so that Stripe retries the delivery
//   - 200 {"received":true} otherwise
//
// SECURITY NOTE(review): the payload is trusted as-is. No Stripe-Signature
// check is performed, so any client able to reach this endpoint can change a
// tenant's plan. Verifying the signature needs the endpoint secret, which
// this handler does not receive.
func HandleStripeWebhook(tenantStore *TenantStore) http.HandlerFunc {
	// Upper bound on the request body; the endpoint is unauthenticated.
	const maxWebhookBody = 64 << 10

	return func(w http.ResponseWriter, r *http.Request) {
		r.Body = http.MaxBytesReader(w, r.Body, maxWebhookBody)

		var evt struct {
			Type string `json:"type"`
			Data struct {
				Object struct {
					CustomerID     string `json:"customer"`
					SubscriptionID string `json:"id"`
					Status         string `json:"status"`
					Metadata       struct {
						TenantID string `json:"tenant_id"`
						PlanID   string `json:"plan_id"`
					} `json:"metadata"`
				} `json:"object"`
			} `json:"data"`
		}
		if err := json.NewDecoder(r.Body).Decode(&evt); err != nil {
			http.Error(w, "invalid payload", http.StatusBadRequest)
			return
		}

		sub := evt.Data.Object
		tenantID := sub.Metadata.TenantID

		// planErr records a failed plan change. It is checked after the
		// switch so both store calls still run, as they did before.
		var planErr error
		switch evt.Type {
		case "customer.subscription.created", "customer.subscription.updated":
			if tenantID != "" {
				tenantStore.SetStripeIDs(tenantID, sub.CustomerID, sub.SubscriptionID)
				if planID := sub.Metadata.PlanID; planID != "" {
					planErr = tenantStore.UpdatePlan(tenantID, planID)
				}
			}
		case "customer.subscription.deleted":
			if tenantID != "" {
				// NOTE(review): new tenants are created elsewhere on the
				// "free" plan; confirm that a cancelled subscription is
				// really meant to land on "starter".
				planErr = tenantStore.UpdatePlan(tenantID, "starter")
				tenantStore.SetStripeIDs(tenantID, sub.CustomerID, "")
			}
		}

		if planErr != nil {
			// A non-2xx answer makes Stripe redeliver the event later.
			slog.Error("stripe webhook: plan update failed",
				"event_type", evt.Type, "tenant_id", tenantID, "error", planErr)
			http.Error(w, "failed to apply event", http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"received":true}`))
	}
}
// usagePercent reports used as a percentage of limit, capped at 100.
// A limit of zero or less yields 0.
func usagePercent(used, limit int) float64 {
	const ceiling = 100.0

	if limit > 0 {
		ratio := float64(used) / float64(limit) * ceiling
		if ratio <= ceiling {
			return ratio
		}
		return ceiling
	}
	return 0
}