From 8d87c453b0effac431b35bcc203892877eb100fb Mon Sep 17 00:00:00 2001 From: DmitrL-dev <84296377+DmitrL-dev@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:37:09 +1000 Subject: [PATCH] feat: add free starter plan with 1000 scans/month quota tracking --- cmd/soc/main.go | 7 + internal/infrastructure/auth/middleware.go | 1 + internal/infrastructure/auth/usage.go | 203 +++++++++++++++++++++ internal/transport/http/server.go | 8 + internal/transport/http/soc_handlers.go | 48 +++++ 5 files changed, 267 insertions(+) create mode 100644 internal/infrastructure/auth/usage.go diff --git a/cmd/soc/main.go b/cmd/soc/main.go index 6ac8ba4..eb9a969 100644 --- a/cmd/soc/main.go +++ b/cmd/soc/main.go @@ -24,6 +24,7 @@ import ( "github.com/syntrex/gomcp/internal/application/soc" socdomain "github.com/syntrex/gomcp/internal/domain/soc" "github.com/syntrex/gomcp/internal/domain/engines" + "github.com/syntrex/gomcp/internal/infrastructure/auth" "github.com/syntrex/gomcp/internal/infrastructure/audit" "github.com/syntrex/gomcp/internal/infrastructure/email" "github.com/syntrex/gomcp/internal/infrastructure/logging" @@ -137,6 +138,12 @@ func main() { logger.Info("JWT authentication configured") } + // Usage/quota tracking — metered free tier (1000 scans/month) + if db, ok := sqlDB.(*sql.DB); ok { + srv.SetUsageTracker(auth.NewUsageTracker(db)) + logger.Info("usage tracker initialized (free tier: 1000 scans/month)") + } + // Email service — Resend (set RESEND_API_KEY to enable real email delivery) if resendKey := env("RESEND_API_KEY", ""); resendKey != "" { fromAddr := env("EMAIL_FROM", "SYNTREX ") diff --git a/internal/infrastructure/auth/middleware.go b/internal/infrastructure/auth/middleware.go index 6d15dfb..40308c0 100644 --- a/internal/infrastructure/auth/middleware.go +++ b/internal/infrastructure/auth/middleware.go @@ -33,6 +33,7 @@ func NewJWTMiddleware(secret []byte) *JWTMiddleware { "/api/auth/verify": true, "/api/auth/plans": true, "/api/v1/scan": true, // public demo scanner + "/api/v1/usage": true, // public usage/quota check "/api/soc/events/stream": true, // SSE uses query param auth "/api/soc/stream": true, // SSE live feed (EventSource can't send headers) "/api/soc/ws": true, // WebSocket-style SSE push diff --git a/internal/infrastructure/auth/usage.go b/internal/infrastructure/auth/usage.go new file mode 100644 index 0000000..4d0a0dd --- /dev/null +++ b/internal/infrastructure/auth/usage.go @@ -0,0 +1,203 @@ +package auth + +import ( + "database/sql" + "fmt" + "log/slog" + "sync" + "time" +) + +// UsageInfo represents current usage state for a caller. +type UsageInfo struct { + Plan string `json:"plan"` + ScansUsed int `json:"scans_used"` + ScansLimit int `json:"scans_limit"` + Remaining int `json:"remaining"` + PeriodStart time.Time `json:"period_start"` + PeriodEnd time.Time `json:"period_end"` + Unlimited bool `json:"unlimited"` +} + +// UsageTracker tracks scan usage per user/IP with monthly quotas. +type UsageTracker struct { + mu sync.Mutex + db *sql.DB +} + +// NewUsageTracker creates a usage tracker backed by PostgreSQL. +func NewUsageTracker(db *sql.DB) *UsageTracker { + t := &UsageTracker{db: db} + if db != nil { + if err := t.migrate(); err != nil { + slog.Error("usage tracker: migration failed", "error", err) + } + // Reset expired quotas on startup + t.ResetExpired() + } + return t +} + +func (t *UsageTracker) migrate() error { + _, err := t.db.Exec(` + CREATE TABLE IF NOT EXISTS usage_quotas ( + id TEXT PRIMARY KEY, + user_id TEXT, + ip_addr TEXT, + plan TEXT NOT NULL DEFAULT 'free', + scans_used INTEGER NOT NULL DEFAULT 0, + scans_limit INTEGER NOT NULL DEFAULT 1000, + period_start TIMESTAMPTZ NOT NULL, + period_end TIMESTAMPTZ NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_usage_user ON usage_quotas(user_id) WHERE user_id IS NOT NULL; + CREATE INDEX IF NOT EXISTS idx_usage_ip ON usage_quotas(ip_addr) WHERE ip_addr IS NOT NULL; + `) + return err +} + +// currentPeriod returns the start and end of the current monthly billing period. +func currentPeriod() (time.Time, time.Time) { + now := time.Now().UTC() + start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) + end := start.AddDate(0, 1, 0) + return start, end +} + +// RecordScan atomically increments the scan counter and checks quota. +// Returns remaining scans. Returns error if quota exceeded. +func (t *UsageTracker) RecordScan(userID, ip string) (int, error) { + if t.db == nil { + return 999, nil // no DB = no limits + } + + t.mu.Lock() + defer t.mu.Unlock() + + periodStart, periodEnd := currentPeriod() + + // Determine lookup key + lookupCol := "ip_addr" + lookupVal := ip + if userID != "" { + lookupCol = "user_id" + lookupVal = userID + } + + // Try to get existing quota record for current period + var scansUsed, scansLimit int + var quotaID string + query := fmt.Sprintf( + `SELECT id, scans_used, scans_limit FROM usage_quotas + WHERE %s = $1 AND period_start = $2`, lookupCol) + + err := t.db.QueryRow(query, lookupVal, periodStart).Scan("aID, &scansUsed, &scansLimit) + + if err == sql.ErrNoRows { + // Create new quota record + quotaID = generateID("usg") + plan := "free" + limit := 1000 + var insertQuery string + if userID != "" { + insertQuery = `INSERT INTO usage_quotas (id, user_id, plan, scans_used, scans_limit, period_start, period_end) + VALUES ($1, $2, $3, 1, $4, $5, $6)` + } else { + insertQuery = `INSERT INTO usage_quotas (id, ip_addr, plan, scans_used, scans_limit, period_start, period_end) + VALUES ($1, $2, $3, 1, $4, $5, $6)` + } + _, err = t.db.Exec(insertQuery, quotaID, lookupVal, plan, limit, periodStart, periodEnd) + if err != nil { + slog.Error("usage: create quota", "error", err) + return 999, nil // fail open — don't block on DB errors + } + return limit - 1, nil + } + + if err != nil { + slog.Error("usage: query quota", "error", err) + return 999, nil // fail open + } + + // Unlimited plan (scans_limit = 0) + if scansLimit == 0 { + t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID) + return -1, nil // unlimited + } + + // Check quota + if scansUsed >= scansLimit { + return 0, fmt.Errorf("quota exceeded: %d/%d scans used this month", scansUsed, scansLimit) + } + + // Increment + _, err = t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID) + if err != nil { + slog.Error("usage: increment", "error", err) + } + + return scansLimit - scansUsed - 1, nil +} + +// GetUsage returns current usage for a user or IP. +func (t *UsageTracker) GetUsage(userID, ip string) *UsageInfo { + if t.db == nil { + return &UsageInfo{Plan: "free", ScansLimit: 1000, Remaining: 1000, Unlimited: false} + } + + periodStart, periodEnd := currentPeriod() + + lookupCol := "ip_addr" + lookupVal := ip + if userID != "" { + lookupCol = "user_id" + lookupVal = userID + } + + var info UsageInfo + query := fmt.Sprintf( + `SELECT plan, scans_used, scans_limit FROM usage_quotas + WHERE %s = $1 AND period_start = $2`, lookupCol) + + err := t.db.QueryRow(query, lookupVal, periodStart).Scan(&info.Plan, &info.ScansUsed, &info.ScansLimit) + if err != nil { + // No usage yet + return &UsageInfo{ + Plan: "free", + ScansUsed: 0, + ScansLimit: 1000, + Remaining: 1000, + PeriodStart: periodStart, + PeriodEnd: periodEnd, + } + } + + info.PeriodStart = periodStart + info.PeriodEnd = periodEnd + if info.ScansLimit == 0 { + info.Unlimited = true + info.Remaining = -1 + } else { + info.Remaining = info.ScansLimit - info.ScansUsed + if info.Remaining < 0 { + info.Remaining = 0 + } + } + + return &info +} + +// ResetExpired cleans up old quota records from previous periods. +func (t *UsageTracker) ResetExpired() { + if t.db == nil { + return + } + result, err := t.db.Exec(`DELETE FROM usage_quotas WHERE period_end < $1`, time.Now().UTC()) + if err != nil { + slog.Error("usage: reset expired", "error", err) + return + } + if n, _ := result.RowsAffected(); n > 0 { + slog.Info("usage: cleaned expired quotas", "count", n) + } +} diff --git a/internal/transport/http/server.go b/internal/transport/http/server.go index 46f44c4..35322a9 100644 --- a/internal/transport/http/server.go +++ b/internal/transport/http/server.go @@ -39,6 +39,7 @@ type Server struct { emailService *email.Service jwtSecret []byte wsHub *WSHub + usageTracker *auth.UsageTracker sovereignEnabled bool sovereignMode string pprofEnabled bool @@ -105,6 +106,11 @@ func (s *Server) SetJWTAuth(secret []byte, db ...*sql.DB) { slog.Info("JWT authentication enabled") } +// SetUsageTracker sets the usage/quota tracker for scan metering. +func (s *Server) SetUsageTracker(tracker *auth.UsageTracker) { + s.usageTracker = tracker +} + // SetRBAC configures RBAC middleware with API key authentication (§17). func (s *Server) SetRBAC(rbac *RBACMiddleware) { s.rbac = rbac @@ -256,6 +262,8 @@ func (s *Server) Start(ctx context.Context) error { // Public scan endpoint — demo scanner (no auth required, rate-limited) mux.HandleFunc("POST /api/v1/scan", s.handlePublicScan) + // Usage endpoint — returns scan quota for caller + mux.HandleFunc("GET /api/v1/usage", s.handleUsage) // pprof debug endpoints (§P4C) — gated behind EnablePprof() if s.pprofEnabled { diff --git a/internal/transport/http/soc_handlers.go b/internal/transport/http/soc_handlers.go index d201ac2..66f7fcf 100644 --- a/internal/transport/http/soc_handlers.go +++ b/internal/transport/http/soc_handlers.go @@ -1475,6 +1475,27 @@ func (s *Server) handlePublicScan(w http.ResponseWriter, r *http.Request) { return } + // Check usage quota (free tier: 1000 scans/month) + if s.usageTracker != nil { + userID := "" + if claims := auth.GetClaims(r.Context()); claims != nil { + userID = claims.Sub + } + ip := r.RemoteAddr + if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" { + ip = fwd + } + remaining, err := s.usageTracker.RecordScan(userID, ip) + if err != nil { + w.Header().Set("X-RateLimit-Remaining", "0") + writeError(w, http.StatusTooManyRequests, "monthly scan quota exceeded — upgrade your plan at syntrex.pro/pricing") + return + } + if remaining >= 0 { + w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(remaining)) + } + } + // Run sentinel-core (54 Rust engines) coreEngine := s.getEngine("sentinel-core") coreResult, coreErr := coreEngine.ScanPrompt(r.Context(), req.Prompt) @@ -1533,3 +1554,30 @@ func (s *Server) handlePublicScan(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, response) } + +// handleUsage returns current scan usage and quota for the caller. +// GET /api/v1/usage +func (s *Server) handleUsage(w http.ResponseWriter, r *http.Request) { + if s.usageTracker == nil { + writeJSON(w, http.StatusOK, map[string]any{ + "plan": "free", + "scans_used": 0, + "scans_limit": 1000, + "remaining": 1000, + "unlimited": false, + }) + return + } + + userID := "" + if claims := auth.GetClaims(r.Context()); claims != nil { + userID = claims.Sub + } + ip := r.RemoteAddr + if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" { + ip = fwd + } + + info := s.usageTracker.GetUsage(userID, ip) + writeJSON(w, http.StatusOK, info) +}