diff --git a/cmd/soc/main.go b/cmd/soc/main.go index ee6069d..b2cc44c 100644 --- a/cmd/soc/main.go +++ b/cmd/soc/main.go @@ -19,6 +19,7 @@ import ( "runtime" "runtime/debug" "strconv" + "strings" "syscall" "github.com/syntrex/gomcp/internal/application/soc" @@ -122,6 +123,16 @@ func main() { socSvc := soc.NewService(socRepo, decisionLogger) srv := sochttp.New(socSvc, port) + // Configure CORS + if corsEnv := env("SOC_CORS_ORIGIN", ""); corsEnv != "" { + parts := strings.Split(corsEnv, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + srv.SetCORSOrigins(parts) + slog.Info("CORS strict origins configured", "origins", len(parts)) + } + // Threat Intelligence Store — always initialized for IOC enrichment (§6) threatIntelStore := soc.NewThreatIntelStore() threatIntelStore.AddDefaultFeeds() diff --git a/internal/transport/http/middleware.go b/internal/transport/http/middleware.go index cf52256..29ddeca 100644 --- a/internal/transport/http/middleware.go +++ b/internal/transport/http/middleware.go @@ -2,57 +2,48 @@ package httpserver import ( "net/http" - "os" - "strings" ) -// corsAllowedOrigins returns the configured CORS origins. -// Set SOC_CORS_ORIGIN in production (e.g. "https://syntrex.pro,https://xn--80akacl3adqr.xn--p1acf"). -// Defaults to "*" for local development. -func corsAllowedOrigins() []string { - if v := os.Getenv("SOC_CORS_ORIGIN"); v != "" { - parts := strings.Split(v, ",") - for i := range parts { - parts[i] = strings.TrimSpace(parts[i]) - } - return parts - } - return []string{"*"} -} - -// corsMiddleware adds CORS headers with configurable origin. -// Production: set SOC_CORS_ORIGIN=https://syntrex.pro,https://xn--80akacl3adqr.xn--p1acf -func corsMiddleware(next http.Handler) http.Handler { - origins := corsAllowedOrigins() - allowAll := len(origins) == 1 && origins[0] == "*" +// corsMiddleware adds CORS headers with strict origin validation. +// Production: SetCORSOrigins should be called with ["https://syntrex.pro"] +func corsMiddleware(origins []string) func(http.Handler) http.Handler { + allowAll := false allowedSet := make(map[string]bool, len(origins)) for _, o := range origins { + if o == "*" { + allowAll = true + } allowedSet[o] = true } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if allowAll { - w.Header().Set("Access-Control-Allow-Origin", "*") - } else { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reqOrigin := r.Header.Get("Origin") - if allowedSet[reqOrigin] { + if reqOrigin != "" { + if !allowAll && !allowedSet[reqOrigin] { + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } w.Header().Set("Access-Control-Allow-Origin", reqOrigin) w.Header().Set("Vary", "Origin") + } else if allowAll { + w.Header().Set("Access-Control-Allow-Origin", "*") } - } - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Max-Age", "86400") + + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Max-Age", "86400") - // Handle preflight - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } + // Handle preflight + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } - next.ServeHTTP(w, r) - }) + next.ServeHTTP(w, r) + }) + } } // securityHeadersMiddleware adds defense-in-depth headers to all responses. diff --git a/internal/transport/http/server.go b/internal/transport/http/server.go index 002a960..b8b557c 100644 --- a/internal/transport/http/server.go +++ b/internal/transport/http/server.go @@ -53,6 +53,7 @@ type Server struct { srv *http.Server tlsCert string tlsKey string + corsOrigins []string } // cachedScan stores a cached scan result with expiry. @@ -80,6 +81,14 @@ func New(socSvc *appsoc.Service, port int) *Server { wsHub: NewWSHub(), scanSem: make(chan struct{}, 6), // Max 6 concurrent scans (~2 per CPU) scanCache: make(map[string]*cachedScan, 500), + corsOrigins: []string{"http://localhost:3000", "https://syntrex.pro"}, // Default secure fallback + } +} + +// SetCORSOrigins configures the allowed origins for CORS strictly. +func (s *Server) SetCORSOrigins(origins []string) { + if len(origins) > 0 { + s.corsOrigins = origins } } @@ -378,7 +387,7 @@ func (s *Server) Start(ctx context.Context) error { if s.jwtAuth != nil { handler = s.jwtAuth.Middleware(handler) } - handler = corsMiddleware(handler) + handler = corsMiddleware(s.corsOrigins)(handler) handler = securityHeadersMiddleware(handler) handler = s.rateLimiter.Middleware(handler) handler = s.metrics.Middleware(handler) diff --git a/internal/transport/http/soc_handlers_test.go b/internal/transport/http/soc_handlers_test.go index 3ebe1d3..f1fea13 100644 --- a/internal/transport/http/soc_handlers_test.go +++ b/internal/transport/http/soc_handlers_test.go @@ -63,7 +63,7 @@ func newTestServer(t *testing.T) (*httptest.Server, *appsoc.Service) { mux.HandleFunc("GET /api/soc/incident-explain/{id}", srv.handleIncidentExplain) mux.HandleFunc("GET /health", srv.handleHealth) - ts := httptest.NewServer(corsMiddleware(mux)) + ts := httptest.NewServer(corsMiddleware([]string{"*"})(mux)) t.Cleanup(ts.Close) return ts, socSvc