gomcp/internal/infrastructure/logging/middleware.go

73 lines
1.7 KiB
Go

package logging
import (
"context"
"crypto/rand"
"fmt"
"log/slog"
"net/http"
"time"
)
// contextKey is an unexported key type for values this package stores in a
// context.Context. Using a dedicated type keeps these keys distinct from
// keys defined by other packages, even if the underlying strings match.
type contextKey string

// requestIDKey is the context key under which the request ID is stored.
const requestIDKey = contextKey("request_id")
// RequestID generates a short unique request ID: 8 random bytes rendered as
// 16 lowercase hexadecimal characters.
//
// The bytes come from crypto/rand. If the random source reports an error, the
// ID is derived from the current time in nanoseconds instead, so callers still
// get a 16-character hex ID rather than one formatted from an unfilled
// (all-zero) buffer that would be identical for every request.
func RequestID() string {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		// Best-effort fallback: the signature has no error return, and a
		// request ID only needs to be distinguishable, not unpredictable.
		return fmt.Sprintf("%016x", uint64(time.Now().UnixNano()))
	}
	return fmt.Sprintf("%x", b[:])
}
// WithRequestID derives a child context from ctx that carries id as its
// request ID. The value is stored under this package's private key and can be
// read back with GetRequestID.
func WithRequestID(ctx context.Context, id string) context.Context {
	tagged := context.WithValue(ctx, requestIDKey, id)
	return tagged
}
// GetRequestID returns the request ID stored in ctx, or the empty string when
// ctx carries no request ID (or the stored value is not a string).
func GetRequestID(ctx context.Context) string {
	// A failed comma-ok assertion yields the zero value, which for string is
	// exactly the "" fallback we want.
	reqID, _ := ctx.Value(requestIDKey).(string)
	return reqID
}
// RequestIDMiddleware wraps next so that every request carries a request ID
// and produces one structured "http_request" log entry when it completes.
//
// The ID is taken from the incoming X-Request-ID header when present;
// otherwise a new one is generated with RequestID. It is echoed back in the
// X-Request-ID response header and stored in the request context, where
// handlers can read it with GetRequestID.
//
// After next returns, one entry is logged at Info level with the method, path,
// response status, duration in whole milliseconds, and request ID. Nothing is
// logged when the request starts. A nil logger falls back to slog.Default()
// instead of panicking while serving the request.
func RequestIDMiddleware(logger *slog.Logger, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Resolve per request so a nil logger tracks the current default.
		lg := logger
		if lg == nil {
			lg = slog.Default()
		}

		reqID := r.Header.Get("X-Request-ID")
		if reqID == "" {
			reqID = RequestID()
		}
		// Set before calling next: headers cannot be changed once the
		// handler has started writing the response.
		w.Header().Set("X-Request-ID", reqID)

		ctx := WithRequestID(r.Context(), reqID)
		r = r.WithContext(ctx)

		start := time.Now()
		// Default to 200: that is the status net/http sends when the handler
		// never calls WriteHeader explicitly.
		wrapped := &statusWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(wrapped, r)
		dur := time.Since(start)

		// Pass the request context so context-aware slog handlers can use it.
		lg.InfoContext(ctx, "http_request",
			"method", r.Method,
			"path", r.URL.Path,
			"status", wrapped.status,
			"duration_ms", dur.Milliseconds(),
			"request_id", reqID,
		)
	})
}
// statusWriter wraps an http.ResponseWriter to capture the status code of the
// response for logging.
//
// status should be initialised by the creator to the code that applies when
// the handler never calls WriteHeader. Only the first final status is
// recorded: later WriteHeader calls are still forwarded to the underlying
// writer but do not change the recorded value, and informational 1xx codes
// (other than 101) are never recorded as the final status.
type statusWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // true once a final status has been committed
}

// WriteHeader records code as the response status if no final status has been
// committed yet, then forwards the call to the underlying writer.
func (w *statusWriter) WriteHeader(code int) {
	// Informational responses may precede the real status, so they must not
	// lock in the recorded value. 101 Switching Protocols is final.
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !w.wroteHeader && !informational {
		w.status = code
		w.wroteHeader = true
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards p to the underlying writer. Writing a body commits the
// headers, so a later WriteHeader call no longer changes the recorded status;
// status keeps the value it had at that point.
func (w *statusWriter) Write(p []byte) (int, error) {
	w.wroteHeader = true
	return w.ResponseWriter.Write(p)
}

// Flush forwards to the underlying writer when it implements http.Flusher, so
// streaming handlers keep working behind this wrapper. It is a no-op when the
// underlying writer cannot flush.
func (w *statusWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		// Flushing sends the headers if they have not been sent yet.
		w.wroteHeader = true
		f.Flush()
	}
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional features (hijacking, deadlines) of the original ResponseWriter.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}