mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-24 20:06:21 +02:00
203 lines
5.4 KiB
Go
203 lines
5.4 KiB
Go
package auth
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// UsageInfo represents current usage state for a caller: a snapshot of the
// caller's scan quota for the current monthly billing period, as returned by
// UsageTracker.GetUsage. JSON field names are given by the struct tags.
type UsageInfo struct {
	// Plan is the caller's plan name; "free" when no quota row exists yet.
	Plan string `json:"plan"`
	// ScansUsed is the number of scans recorded in the current period.
	ScansUsed int `json:"scans_used"`
	// ScansLimit is the scan allowance for the period; 0 means unlimited.
	ScansLimit int `json:"scans_limit"`
	// Remaining is ScansLimit - ScansUsed clamped at 0, or -1 when Unlimited.
	Remaining int `json:"remaining"`
	// PeriodStart is midnight UTC on the first day of the current month.
	PeriodStart time.Time `json:"period_start"`
	// PeriodEnd is midnight UTC on the first day of the following month.
	PeriodEnd time.Time `json:"period_end"`
	// Unlimited reports whether the plan has no scan cap (ScansLimit == 0).
	Unlimited bool `json:"unlimited"`
}
|
|
|
|
// UsageTracker tracks scan usage per user/IP with monthly quotas stored in
// the usage_quotas table.
type UsageTracker struct {
	// mu serializes RecordScan's read-then-write sequence within this process
	// only; it provides no coordination with other processes sharing db.
	mu sync.Mutex

	// db is the backing database. When nil, no limits are enforced:
	// RecordScan always allows the scan and GetUsage reports free-plan defaults.
	db *sql.DB
}
|
|
|
|
// NewUsageTracker creates a usage tracker backed by PostgreSQL.
|
|
func NewUsageTracker(db *sql.DB) *UsageTracker {
|
|
t := &UsageTracker{db: db}
|
|
if db != nil {
|
|
if err := t.migrate(); err != nil {
|
|
slog.Error("usage tracker: migration failed", "error", err)
|
|
}
|
|
// Reset expired quotas on startup
|
|
t.ResetExpired()
|
|
}
|
|
return t
|
|
}
|
|
|
|
func (t *UsageTracker) migrate() error {
|
|
_, err := t.db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS usage_quotas (
|
|
id TEXT PRIMARY KEY,
|
|
user_id TEXT,
|
|
ip_addr TEXT,
|
|
plan TEXT NOT NULL DEFAULT 'free',
|
|
scans_used INTEGER NOT NULL DEFAULT 0,
|
|
scans_limit INTEGER NOT NULL DEFAULT 1000,
|
|
period_start TIMESTAMPTZ NOT NULL,
|
|
period_end TIMESTAMPTZ NOT NULL
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_usage_user ON usage_quotas(user_id) WHERE user_id IS NOT NULL;
|
|
CREATE INDEX IF NOT EXISTS idx_usage_ip ON usage_quotas(ip_addr) WHERE ip_addr IS NOT NULL;
|
|
`)
|
|
return err
|
|
}
|
|
|
|
// currentPeriod returns the start and end of the current monthly billing
// period in UTC. The start is midnight on the first day of this month. The
// end is midnight on the first day of the next month, i.e. the start of the
// following period.
func currentPeriod() (time.Time, time.Time) {
	year, month, _ := time.Now().UTC().Date()
	first := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC)
	return first, first.AddDate(0, 1, 0)
}
|
|
|
|
// RecordScan atomically increments the scan counter and checks quota.
|
|
// Returns remaining scans. Returns error if quota exceeded.
|
|
func (t *UsageTracker) RecordScan(userID, ip string) (int, error) {
|
|
if t.db == nil {
|
|
return 999, nil // no DB = no limits
|
|
}
|
|
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
|
|
periodStart, periodEnd := currentPeriod()
|
|
|
|
// Determine lookup key
|
|
lookupCol := "ip_addr"
|
|
lookupVal := ip
|
|
if userID != "" {
|
|
lookupCol = "user_id"
|
|
lookupVal = userID
|
|
}
|
|
|
|
// Try to get existing quota record for current period
|
|
var scansUsed, scansLimit int
|
|
var quotaID string
|
|
query := fmt.Sprintf(
|
|
`SELECT id, scans_used, scans_limit FROM usage_quotas
|
|
WHERE %s = $1 AND period_start = $2`, lookupCol)
|
|
|
|
err := t.db.QueryRow(query, lookupVal, periodStart).Scan("aID, &scansUsed, &scansLimit)
|
|
|
|
if err == sql.ErrNoRows {
|
|
// Create new quota record
|
|
quotaID = generateID("usg")
|
|
plan := "free"
|
|
limit := 1000
|
|
var insertQuery string
|
|
if userID != "" {
|
|
insertQuery = `INSERT INTO usage_quotas (id, user_id, plan, scans_used, scans_limit, period_start, period_end)
|
|
VALUES ($1, $2, $3, 1, $4, $5, $6)`
|
|
} else {
|
|
insertQuery = `INSERT INTO usage_quotas (id, ip_addr, plan, scans_used, scans_limit, period_start, period_end)
|
|
VALUES ($1, $2, $3, 1, $4, $5, $6)`
|
|
}
|
|
_, err = t.db.Exec(insertQuery, quotaID, lookupVal, plan, limit, periodStart, periodEnd)
|
|
if err != nil {
|
|
slog.Error("usage: create quota", "error", err)
|
|
return 999, nil // fail open — don't block on DB errors
|
|
}
|
|
return limit - 1, nil
|
|
}
|
|
|
|
if err != nil {
|
|
slog.Error("usage: query quota", "error", err)
|
|
return 999, nil // fail open
|
|
}
|
|
|
|
// Unlimited plan (scans_limit = 0)
|
|
if scansLimit == 0 {
|
|
t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID)
|
|
return -1, nil // unlimited
|
|
}
|
|
|
|
// Check quota
|
|
if scansUsed >= scansLimit {
|
|
return 0, fmt.Errorf("quota exceeded: %d/%d scans used this month", scansUsed, scansLimit)
|
|
}
|
|
|
|
// Increment
|
|
_, err = t.db.Exec(`UPDATE usage_quotas SET scans_used = scans_used + 1 WHERE id = $1`, quotaID)
|
|
if err != nil {
|
|
slog.Error("usage: increment", "error", err)
|
|
}
|
|
|
|
return scansLimit - scansUsed - 1, nil
|
|
}
|
|
|
|
// GetUsage returns current usage for a user or IP.
|
|
func (t *UsageTracker) GetUsage(userID, ip string) *UsageInfo {
|
|
if t.db == nil {
|
|
return &UsageInfo{Plan: "free", ScansLimit: 1000, Remaining: 1000, Unlimited: false}
|
|
}
|
|
|
|
periodStart, periodEnd := currentPeriod()
|
|
|
|
lookupCol := "ip_addr"
|
|
lookupVal := ip
|
|
if userID != "" {
|
|
lookupCol = "user_id"
|
|
lookupVal = userID
|
|
}
|
|
|
|
var info UsageInfo
|
|
query := fmt.Sprintf(
|
|
`SELECT plan, scans_used, scans_limit FROM usage_quotas
|
|
WHERE %s = $1 AND period_start = $2`, lookupCol)
|
|
|
|
err := t.db.QueryRow(query, lookupVal, periodStart).Scan(&info.Plan, &info.ScansUsed, &info.ScansLimit)
|
|
if err != nil {
|
|
// No usage yet
|
|
return &UsageInfo{
|
|
Plan: "free",
|
|
ScansUsed: 0,
|
|
ScansLimit: 1000,
|
|
Remaining: 1000,
|
|
PeriodStart: periodStart,
|
|
PeriodEnd: periodEnd,
|
|
}
|
|
}
|
|
|
|
info.PeriodStart = periodStart
|
|
info.PeriodEnd = periodEnd
|
|
if info.ScansLimit == 0 {
|
|
info.Unlimited = true
|
|
info.Remaining = -1
|
|
} else {
|
|
info.Remaining = info.ScansLimit - info.ScansUsed
|
|
if info.Remaining < 0 {
|
|
info.Remaining = 0
|
|
}
|
|
}
|
|
|
|
return &info
|
|
}
|
|
|
|
// ResetExpired cleans up old quota records from previous periods.
|
|
func (t *UsageTracker) ResetExpired() {
|
|
if t.db == nil {
|
|
return
|
|
}
|
|
result, err := t.db.Exec(`DELETE FROM usage_quotas WHERE period_end < $1`, time.Now().UTC())
|
|
if err != nil {
|
|
slog.Error("usage: reset expired", "error", err)
|
|
return
|
|
}
|
|
if n, _ := result.RowsAffected(); n > 0 {
|
|
slog.Info("usage: cleaned expired quotas", "count", n)
|
|
}
|
|
}
|