mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-25 04:16:22 +02:00
feat: add free starter plan with 1000 scans/month quota tracking
This commit is contained in:
parent
f581d65951
commit
8d87c453b0
5 changed files with 267 additions and 0 deletions
203
internal/infrastructure/auth/usage.go
Normal file
203
internal/infrastructure/auth/usage.go
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UsageInfo represents current usage state for a caller.
|
||||
type UsageInfo struct {
|
||||
Plan string `json:"plan"`
|
||||
ScansUsed int `json:"scans_used"`
|
||||
ScansLimit int `json:"scans_limit"`
|
||||
Remaining int `json:"remaining"`
|
||||
PeriodStart time.Time `json:"period_start"`
|
||||
PeriodEnd time.Time `json:"period_end"`
|
||||
Unlimited bool `json:"unlimited"`
|
||||
}
|
||||
|
||||
// UsageTracker tracks scan usage per user/IP with monthly quotas.
|
||||
type UsageTracker struct {
|
||||
mu sync.Mutex
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewUsageTracker creates a usage tracker backed by PostgreSQL.
|
||||
func NewUsageTracker(db *sql.DB) *UsageTracker {
|
||||
t := &UsageTracker{db: db}
|
||||
if db != nil {
|
||||
if err := t.migrate(); err != nil {
|
||||
slog.Error("usage tracker: migration failed", "error", err)
|
||||
}
|
||||
// Reset expired quotas on startup
|
||||
t.ResetExpired()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *UsageTracker) migrate() error {
|
||||
_, err := t.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS usage_quotas (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
ip_addr TEXT,
|
||||
plan TEXT NOT NULL DEFAULT 'free',
|
||||
scans_used INTEGER NOT NULL DEFAULT 0,
|
||||
scans_limit INTEGER NOT NULL DEFAULT 1000,
|
||||
period_start TIMESTAMPTZ NOT NULL,
|
||||
period_end TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_user ON usage_quotas(user_id) WHERE user_id IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_ip ON usage_quotas(ip_addr) WHERE ip_addr IS NOT NULL;
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// currentPeriod returns the start and end of the current monthly billing period.
|
||||
func currentPeriod() (time.Time, time.Time) {
|
||||
now := time.Now().UTC()
|
||||
start := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.AddDate(0, 1, 0)
|
||||
return start, end
|
||||
}
|
||||
|
||||
// RecordScan atomically increments the scan counter and checks quota.
|
||||
// Returns remaining scans. Returns error if quota exceeded.
|
||||
func (t *UsageTracker) RecordScan(userID, ip string) (int, error) {
|
||||
if t.db == nil {
|
||||
return 999, nil // no DB = no limits
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
periodStart, periodEnd := currentPeriod()
|
||||
|
||||
// Determine lookup key
|
||||
lookupCol := "ip_addr"
|
||||
lookupVal := ip
|
||||
if userID != "" {
|
||||
lookupCol = "user_id"
|
||||
lookupVal = userID
|
||||
}
|
||||
|
||||
// Try to get existing quota record for current period
|
||||
var scansUsed, scansLimit int
|
||||
var quotaID string
|
||||
query := fmt.Sprintf(
|
||||
`SELECT id, scans_used, scans_limit FROM usage_quotas
|
||||
WHERE %s = $1 AND period_start = $2`, lookupCol)
|
||||
|
||||
err := t.db.QueryRow(query, lookupVal, periodStart).Scan("aID, &scansUsed, &scansLimit)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// Create new quota record
|
||||
quotaID = generateID("usg")
|
||||
plan := "free"
|
||||
limit := 1000
|
||||
var insertQuery string
|
||||
if userID != "" {
|
||||
insertQuery = `INSERT INTO usage_quotas (id, user_id, plan, scans_used, scans_limit, period_start, period_end)
|
||||
VALUES ($1, $2, $3, 1, $4, $5, $6)`
|
||||
} else {
|
||||
insertQuery = `INSERT INTO usage_quotas (id, ip_addr, plan, scans_used, scans_limit, period_start, period_end)
|
||||
VALUES ($1, $2, $3, 1, $4, $5, $6)`
|
||||
}
|
||||
_, err = t.db.Exec(insertQuery, quotaID, lookupVal, plan, limit, periodStart, periodEnd)
|
||||
if err != nil {
|
||||
slog.Error("usage: create quota", "error", err)
|
||||
return 999, nil // fail open — don't block on DB errors
|
||||
}
|
||||
return limit - 1, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
slog.Error("usage: query quota", "error", err)
|
||||
return 999, nil // fail open
|
||||
}
|
||||
|
||||
// Unlimited plan (scans_limit = 0)
|
||||
if scansLimit == 0 {
|
||||
t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID)
|
||||
return -1, nil // unlimited
|
||||
}
|
||||
|
||||
// Check quota
|
||||
if scansUsed >= scansLimit {
|
||||
return 0, fmt.Errorf("quota exceeded: %d/%d scans used this month", scansUsed, scansLimit)
|
||||
}
|
||||
|
||||
// Increment
|
||||
_, err = t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID)
|
||||
if err != nil {
|
||||
slog.Error("usage: increment", "error", err)
|
||||
}
|
||||
|
||||
return scansLimit - scansUsed - 1, nil
|
||||
}
|
||||
|
||||
// GetUsage returns current usage for a user or IP.
|
||||
func (t *UsageTracker) GetUsage(userID, ip string) *UsageInfo {
|
||||
if t.db == nil {
|
||||
return &UsageInfo{Plan: "free", ScansLimit: 1000, Remaining: 1000, Unlimited: false}
|
||||
}
|
||||
|
||||
periodStart, periodEnd := currentPeriod()
|
||||
|
||||
lookupCol := "ip_addr"
|
||||
lookupVal := ip
|
||||
if userID != "" {
|
||||
lookupCol = "user_id"
|
||||
lookupVal = userID
|
||||
}
|
||||
|
||||
var info UsageInfo
|
||||
query := fmt.Sprintf(
|
||||
`SELECT plan, scans_used, scans_limit FROM usage_quotas
|
||||
WHERE %s = $1 AND period_start = $2`, lookupCol)
|
||||
|
||||
err := t.db.QueryRow(query, lookupVal, periodStart).Scan(&info.Plan, &info.ScansUsed, &info.ScansLimit)
|
||||
if err != nil {
|
||||
// No usage yet
|
||||
return &UsageInfo{
|
||||
Plan: "free",
|
||||
ScansUsed: 0,
|
||||
ScansLimit: 1000,
|
||||
Remaining: 1000,
|
||||
PeriodStart: periodStart,
|
||||
PeriodEnd: periodEnd,
|
||||
}
|
||||
}
|
||||
|
||||
info.PeriodStart = periodStart
|
||||
info.PeriodEnd = periodEnd
|
||||
if info.ScansLimit == 0 {
|
||||
info.Unlimited = true
|
||||
info.Remaining = -1
|
||||
} else {
|
||||
info.Remaining = info.ScansLimit - info.ScansUsed
|
||||
if info.Remaining < 0 {
|
||||
info.Remaining = 0
|
||||
}
|
||||
}
|
||||
|
||||
return &info
|
||||
}
|
||||
|
||||
// ResetExpired cleans up old quota records from previous periods.
|
||||
func (t *UsageTracker) ResetExpired() {
|
||||
if t.db == nil {
|
||||
return
|
||||
}
|
||||
result, err := t.db.Exec(`DELETE FROM usage_quotas WHERE period_end < $1`, time.Now().UTC())
|
||||
if err != nil {
|
||||
slog.Error("usage: reset expired", "error", err)
|
||||
return
|
||||
}
|
||||
if n, _ := result.RowsAffected(); n > 0 {
|
||||
slog.Info("usage: cleaned expired quotas", "count", n)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue