mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-24 20:06:21 +02:00
feat: add free starter plan with 1000 scans/month quota tracking
This commit is contained in:
parent
f581d65951
commit
8d87c453b0
5 changed files with 267 additions and 0 deletions
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/syntrex/gomcp/internal/application/soc"
|
||||
socdomain "github.com/syntrex/gomcp/internal/domain/soc"
|
||||
"github.com/syntrex/gomcp/internal/domain/engines"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/auth"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/audit"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/email"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/logging"
|
||||
|
|
@ -137,6 +138,12 @@ func main() {
|
|||
logger.Info("JWT authentication configured")
|
||||
}
|
||||
|
||||
// Usage/quota tracking — metered free tier (1000 scans/month)
|
||||
if db, ok := sqlDB.(*sql.DB); ok {
|
||||
srv.SetUsageTracker(auth.NewUsageTracker(db))
|
||||
logger.Info("usage tracker initialized (free tier: 1000 scans/month)")
|
||||
}
|
||||
|
||||
// Email service — Resend (set RESEND_API_KEY to enable real email delivery)
|
||||
if resendKey := env("RESEND_API_KEY", ""); resendKey != "" {
|
||||
fromAddr := env("EMAIL_FROM", "SYNTREX <noreply@xn--80akacl3adqr.xn--p1acf>")
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ func NewJWTMiddleware(secret []byte) *JWTMiddleware {
|
|||
"/api/auth/verify": true,
|
||||
"/api/auth/plans": true,
|
||||
"/api/v1/scan": true, // public demo scanner
|
||||
"/api/v1/usage": true, // public usage/quota check
|
||||
"/api/soc/events/stream": true, // SSE uses query param auth
|
||||
"/api/soc/stream": true, // SSE live feed (EventSource can't send headers)
|
||||
"/api/soc/ws": true, // WebSocket-style SSE push
|
||||
|
|
|
|||
203
internal/infrastructure/auth/usage.go
Normal file
203
internal/infrastructure/auth/usage.go
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UsageInfo represents current usage state for a caller.
|
||||
type UsageInfo struct {
|
||||
Plan string `json:"plan"`
|
||||
ScansUsed int `json:"scans_used"`
|
||||
ScansLimit int `json:"scans_limit"`
|
||||
Remaining int `json:"remaining"`
|
||||
PeriodStart time.Time `json:"period_start"`
|
||||
PeriodEnd time.Time `json:"period_end"`
|
||||
Unlimited bool `json:"unlimited"`
|
||||
}
|
||||
|
||||
// UsageTracker tracks scan usage per user/IP with monthly quotas.
|
||||
type UsageTracker struct {
|
||||
mu sync.Mutex
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewUsageTracker creates a usage tracker backed by PostgreSQL.
|
||||
func NewUsageTracker(db *sql.DB) *UsageTracker {
|
||||
t := &UsageTracker{db: db}
|
||||
if db != nil {
|
||||
if err := t.migrate(); err != nil {
|
||||
slog.Error("usage tracker: migration failed", "error", err)
|
||||
}
|
||||
// Reset expired quotas on startup
|
||||
t.ResetExpired()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *UsageTracker) migrate() error {
|
||||
_, err := t.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS usage_quotas (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
ip_addr TEXT,
|
||||
plan TEXT NOT NULL DEFAULT 'free',
|
||||
scans_used INTEGER NOT NULL DEFAULT 0,
|
||||
scans_limit INTEGER NOT NULL DEFAULT 1000,
|
||||
period_start TIMESTAMPTZ NOT NULL,
|
||||
period_end TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_user ON usage_quotas(user_id) WHERE user_id IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_ip ON usage_quotas(ip_addr) WHERE ip_addr IS NOT NULL;
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// currentPeriod returns the start and end of the current monthly billing period.
|
||||
func currentPeriod() (time.Time, time.Time) {
|
||||
now := time.Now().UTC()
|
||||
start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.AddDate(0, 1, 0)
|
||||
return start, end
|
||||
}
|
||||
|
||||
// RecordScan atomically increments the scan counter and checks quota.
|
||||
// Returns remaining scans. Returns error if quota exceeded.
|
||||
func (t *UsageTracker) RecordScan(userID, ip string) (int, error) {
|
||||
if t.db == nil {
|
||||
return 999, nil // no DB = no limits
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
periodStart, periodEnd := currentPeriod()
|
||||
|
||||
// Determine lookup key
|
||||
lookupCol := "ip_addr"
|
||||
lookupVal := ip
|
||||
if userID != "" {
|
||||
lookupCol = "user_id"
|
||||
lookupVal = userID
|
||||
}
|
||||
|
||||
// Try to get existing quota record for current period
|
||||
var scansUsed, scansLimit int
|
||||
var quotaID string
|
||||
query := fmt.Sprintf(
|
||||
`SELECT id, scans_used, scans_limit FROM usage_quotas
|
||||
WHERE %s = $1 AND period_start = $2`, lookupCol)
|
||||
|
||||
err := t.db.QueryRow(query, lookupVal, periodStart).Scan("aID, &scansUsed, &scansLimit)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Create new quota record
|
||||
quotaID = generateID("usg")
|
||||
plan := "free"
|
||||
limit := 1000
|
||||
var insertQuery string
|
||||
if userID != "" {
|
||||
insertQuery = `INSERT INTO usage_quotas (id, user_id, plan, scans_used, scans_limit, period_start, period_end)
|
||||
VALUES ($1, $2, $3, 1, $4, $5, $6)`
|
||||
} else {
|
||||
insertQuery = `INSERT INTO usage_quotas (id, ip_addr, plan, scans_used, scans_limit, period_start, period_end)
|
||||
VALUES ($1, $2, $3, 1, $4, $5, $6)`
|
||||
}
|
||||
_, err = t.db.Exec(insertQuery, quotaID, lookupVal, plan, limit, periodStart, periodEnd)
|
||||
if err != nil {
|
||||
slog.Error("usage: create quota", "error", err)
|
||||
return 999, nil // fail open — don't block on DB errors
|
||||
}
|
||||
return limit - 1, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
slog.Error("usage: query quota", "error", err)
|
||||
return 999, nil // fail open
|
||||
}
|
||||
|
||||
// Unlimited plan (scans_limit = 0)
|
||||
if scansLimit == 0 {
|
||||
t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID)
|
||||
return -1, nil // unlimited
|
||||
}
|
||||
|
||||
// Check quota
|
||||
if scansUsed >= scansLimit {
|
||||
return 0, fmt.Errorf("quota exceeded: %d/%d scans used this month", scansUsed, scansLimit)
|
||||
}
|
||||
|
||||
// Increment
|
||||
_, err = t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID)
|
||||
if err != nil {
|
||||
slog.Error("usage: increment", "error", err)
|
||||
}
|
||||
|
||||
return scansLimit - scansUsed - 1, nil
|
||||
}
|
||||
|
||||
// GetUsage returns current usage for a user or IP.
|
||||
func (t *UsageTracker) GetUsage(userID, ip string) *UsageInfo {
|
||||
if t.db == nil {
|
||||
return &UsageInfo{Plan: "free", ScansLimit: 1000, Remaining: 1000, Unlimited: false}
|
||||
}
|
||||
|
||||
periodStart, periodEnd := currentPeriod()
|
||||
|
||||
lookupCol := "ip_addr"
|
||||
lookupVal := ip
|
||||
if userID != "" {
|
||||
lookupCol = "user_id"
|
||||
lookupVal = userID
|
||||
}
|
||||
|
||||
var info UsageInfo
|
||||
query := fmt.Sprintf(
|
||||
`SELECT plan, scans_used, scans_limit FROM usage_quotas
|
||||
WHERE %s = $1 AND period_start = $2`, lookupCol)
|
||||
|
||||
err := t.db.QueryRow(query, lookupVal, periodStart).Scan(&info.Plan, &info.ScansUsed, &info.ScansLimit)
|
||||
if err != nil {
|
||||
// No usage yet
|
||||
return &UsageInfo{
|
||||
Plan: "free",
|
||||
ScansUsed: 0,
|
||||
ScansLimit: 1000,
|
||||
Remaining: 1000,
|
||||
PeriodStart: periodStart,
|
||||
PeriodEnd: periodEnd,
|
||||
}
|
||||
}
|
||||
|
||||
info.PeriodStart = periodStart
|
||||
info.PeriodEnd = periodEnd
|
||||
if info.ScansLimit == 0 {
|
||||
info.Unlimited = true
|
||||
info.Remaining = -1
|
||||
} else {
|
||||
info.Remaining = info.ScansLimit - info.ScansUsed
|
||||
if info.Remaining < 0 {
|
||||
info.Remaining = 0
|
||||
}
|
||||
}
|
||||
|
||||
return &info
|
||||
}
|
||||
|
||||
// ResetExpired cleans up old quota records from previous periods.
|
||||
func (t *UsageTracker) ResetExpired() {
|
||||
if t.db == nil {
|
||||
return
|
||||
}
|
||||
result, err := t.db.Exec(`DELETE FROM usage_quotas WHERE period_end < $1`, time.Now().UTC())
|
||||
if err != nil {
|
||||
slog.Error("usage: reset expired", "error", err)
|
||||
return
|
||||
}
|
||||
if n, _ := result.RowsAffected(); n > 0 {
|
||||
slog.Info("usage: cleaned expired quotas", "count", n)
|
||||
}
|
||||
}
|
||||
|
|
@ -39,6 +39,7 @@ type Server struct {
|
|||
emailService *email.Service
|
||||
jwtSecret []byte
|
||||
wsHub *WSHub
|
||||
usageTracker *auth.UsageTracker
|
||||
sovereignEnabled bool
|
||||
sovereignMode string
|
||||
pprofEnabled bool
|
||||
|
|
@ -105,6 +106,11 @@ func (s *Server) SetJWTAuth(secret []byte, db ...*sql.DB) {
|
|||
slog.Info("JWT authentication enabled")
|
||||
}
|
||||
|
||||
// SetUsageTracker sets the usage/quota tracker for scan metering.
|
||||
func (s *Server) SetUsageTracker(tracker *auth.UsageTracker) {
|
||||
s.usageTracker = tracker
|
||||
}
|
||||
|
||||
// SetRBAC configures RBAC middleware with API key authentication (§17).
|
||||
func (s *Server) SetRBAC(rbac *RBACMiddleware) {
|
||||
s.rbac = rbac
|
||||
|
|
@ -256,6 +262,8 @@ func (s *Server) Start(ctx context.Context) error {
|
|||
|
||||
// Public scan endpoint — demo scanner (no auth required, rate-limited)
|
||||
mux.HandleFunc("POST /api/v1/scan", s.handlePublicScan)
|
||||
// Usage endpoint — returns scan quota for caller
|
||||
mux.HandleFunc("GET /api/v1/usage", s.handleUsage)
|
||||
|
||||
// pprof debug endpoints (§P4C) — gated behind EnablePprof()
|
||||
if s.pprofEnabled {
|
||||
|
|
|
|||
|
|
@ -1475,6 +1475,27 @@ func (s *Server) handlePublicScan(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Check usage quota (free tier: 1000 scans/month)
|
||||
if s.usageTracker != nil {
|
||||
userID := ""
|
||||
if claims := auth.GetClaims(r.Context()); claims != nil {
|
||||
userID = claims.Sub
|
||||
}
|
||||
ip := r.RemoteAddr
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
ip = fwd
|
||||
}
|
||||
remaining, err := s.usageTracker.RecordScan(userID, ip)
|
||||
if err != nil {
|
||||
w.Header().Set("X-RateLimit-Remaining", "0")
|
||||
writeError(w, http.StatusTooManyRequests, "monthly scan quota exceeded — upgrade your plan at syntrex.pro/pricing")
|
||||
return
|
||||
}
|
||||
if remaining >= 0 {
|
||||
w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
// Run sentinel-core (54 Rust engines)
|
||||
coreEngine := s.getEngine("sentinel-core")
|
||||
coreResult, coreErr := coreEngine.ScanPrompt(r.Context(), req.Prompt)
|
||||
|
|
@ -1533,3 +1554,30 @@ func (s *Server) handlePublicScan(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
writeJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// handleUsage returns current scan usage and quota for the caller.
|
||||
// GET /api/v1/usage
|
||||
func (s *Server) handleUsage(w http.ResponseWriter, r *http.Request) {
|
||||
if s.usageTracker == nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"plan": "free",
|
||||
"scans_used": 0,
|
||||
"scans_limit": 1000,
|
||||
"remaining": 1000,
|
||||
"unlimited": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := ""
|
||||
if claims := auth.GetClaims(r.Context()); claims != nil {
|
||||
userID = claims.Sub
|
||||
}
|
||||
ip := r.RemoteAddr
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
ip = fwd
|
||||
}
|
||||
|
||||
info := s.usageTracker.GetUsage(userID, ip)
|
||||
writeJSON(w, http.StatusOK, info)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue