gomcp/internal/infrastructure/auth/rate_limiter.go

119 lines
2.5 KiB
Go

package auth
import (
"net/http"
"sync"
"time"
)
// RateLimiter tracks login attempts per IP using a sliding window.
type RateLimiter struct {
mu sync.Mutex
attempts map[string]*ipBucket
maxHits int
window time.Duration
cleanup time.Duration
}
type ipBucket struct {
timestamps []time.Time
}
// NewRateLimiter creates a rate limiter.
// maxHits: max attempts per window per IP.
// window: sliding window duration.
func NewRateLimiter(maxHits int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
attempts: make(map[string]*ipBucket),
maxHits: maxHits,
window: window,
cleanup: 5 * time.Minute,
}
go rl.cleanupLoop()
return rl
}
// Allow checks if the IP is within the rate limit.
// Returns true if allowed, false if rate-limited.
func (rl *RateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
bucket, ok := rl.attempts[ip]
if !ok {
bucket = &ipBucket{}
rl.attempts[ip] = bucket
}
// Prune old timestamps outside the window.
cutoff := now.Add(-rl.window)
valid := bucket.timestamps[:0]
for _, t := range bucket.timestamps {
if t.After(cutoff) {
valid = append(valid, t)
}
}
bucket.timestamps = valid
if len(bucket.timestamps) >= rl.maxHits {
return false
}
bucket.timestamps = append(bucket.timestamps, now)
return true
}
// Reset clears attempts for an IP (e.g., on successful login).
func (rl *RateLimiter) Reset(ip string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.attempts, ip)
}
func (rl *RateLimiter) cleanupLoop() {
ticker := time.NewTicker(rl.cleanup)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
now := time.Now()
cutoff := now.Add(-rl.window)
for ip, bucket := range rl.attempts {
valid := bucket.timestamps[:0]
for _, t := range bucket.timestamps {
if t.After(cutoff) {
valid = append(valid, t)
}
}
if len(valid) == 0 {
delete(rl.attempts, ip)
} else {
bucket.timestamps = valid
}
}
rl.mu.Unlock()
}
}
// RateLimitMiddleware wraps an http.HandlerFunc with rate limiting.
func RateLimitMiddleware(rl *RateLimiter, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ip := r.RemoteAddr
// Strip port if present.
if idx := len(ip) - 1; idx > 0 {
for i := idx; i >= 0; i-- {
if ip[i] == ':' {
ip = ip[:i]
break
}
}
}
if !rl.Allow(ip) {
w.Header().Set("Retry-After", "60")
writeAuthError(w, http.StatusTooManyRequests, "rate limit exceeded — try again later")
return
}
next(w, r)
}
}