gomcp/internal/application/shadow_ai/fallback.go

152 lines
4.5 KiB
Go

// Copyright 2026 Syntrex Lab. All rights reserved.
// Use of this source code is governed by an Apache-2.0 license
// that can be found in the LICENSE file.
package shadow_ai
import (
"context"
"fmt"
"log/slog"
"time"
)
// FallbackManager provides priority-based enforcement with graceful degradation.
// Tries enforcement points in priority order; falls back to detect_only if all are offline.
type FallbackManager struct {
registry *PluginRegistry
priority []PluginType // e.g., ["proxy", "firewall", "edr"]
strategy string // "detect_only" | "alert_only"
logger *slog.Logger
// Event logging for detect-only fallback.
eventLogFn func(event ShadowAIEvent)
}
// NewFallbackManager creates a new fallback manager with the given enforcement priority.
func NewFallbackManager(registry *PluginRegistry, strategy string) *FallbackManager {
if strategy == "" {
strategy = "detect_only"
}
return &FallbackManager{
registry: registry,
priority: []PluginType{PluginTypeProxy, PluginTypeFirewall, PluginTypeEDR},
strategy: strategy,
logger: slog.Default().With("component", "shadow-ai-fallback"),
}
}
// SetEventLogger sets the callback for logging detection-only events.
func (fm *FallbackManager) SetEventLogger(fn func(ShadowAIEvent)) {
fm.eventLogFn = fn
}
// BlockDomain attempts to block a domain using the highest-priority healthy plugin.
// Returns the vendor that enforced, or falls back to detect_only mode.
func (fm *FallbackManager) BlockDomain(ctx context.Context, domain, reason string) (enforcedBy string, err error) {
for _, pType := range fm.priority {
plugins := fm.registry.GetByType(pType)
for _, plugin := range plugins {
ne, ok := plugin.(NetworkEnforcer)
if !ok {
// Try WebGateway for URL-based blocking.
if wg, ok := plugin.(WebGateway); ok {
vendor := wg.Vendor()
if !fm.registry.IsHealthy(vendor) {
continue
}
if err := wg.BlockURL(ctx, domain, reason); err != nil {
fm.logger.Warn("block failed on gateway", "vendor", vendor, "error", err)
continue
}
return vendor, nil
}
continue
}
vendor := ne.Vendor()
if !fm.registry.IsHealthy(vendor) {
continue
}
if err := ne.BlockDomain(ctx, domain, reason); err != nil {
fm.logger.Warn("block failed on enforcer", "vendor", vendor, "error", err)
continue
}
return vendor, nil
}
}
// All enforcement points unavailable — fallback.
fm.logger.Warn("all enforcement points unavailable, falling to detect_only",
"domain", domain,
"strategy", fm.strategy,
)
fm.logDetectOnly(domain, reason)
return "", nil
}
// BlockIP attempts to block an IP using the highest-priority healthy firewall.
func (fm *FallbackManager) BlockIP(ctx context.Context, ip string, duration time.Duration, reason string) (enforcedBy string, err error) {
enforcers := fm.registry.GetNetworkEnforcers()
for _, ne := range enforcers {
vendor := ne.Vendor()
if !fm.registry.IsHealthy(vendor) {
continue
}
if err := ne.BlockIP(ctx, ip, duration, reason); err != nil {
fm.logger.Warn("block IP failed", "vendor", vendor, "error", err)
continue
}
return vendor, nil
}
fm.logger.Warn("no healthy enforcer for IP block, falling to detect_only",
"ip", ip,
"strategy", fm.strategy,
)
fm.logDetectOnly(ip, reason)
return "", nil
}
// IsolateHost attempts to isolate a host using the highest-priority healthy EDR.
func (fm *FallbackManager) IsolateHost(ctx context.Context, hostname string) (enforcedBy string, err error) {
controllers := fm.registry.GetEndpointControllers()
for _, ec := range controllers {
vendor := ec.Vendor()
if !fm.registry.IsHealthy(vendor) {
continue
}
if err := ec.IsolateHost(ctx, hostname); err != nil {
fm.logger.Warn("isolate failed", "vendor", vendor, "error", err)
continue
}
return vendor, nil
}
fm.logger.Warn("no healthy EDR for host isolation, falling to detect_only",
"hostname", hostname,
"strategy", fm.strategy,
)
return "", fmt.Errorf("no healthy EDR available for host isolation")
}
// logDetectOnly records a detection-only event when no enforcement is possible.
func (fm *FallbackManager) logDetectOnly(target, reason string) {
if fm.eventLogFn != nil {
fm.eventLogFn(ShadowAIEvent{
Destination: target,
DetectionMethod: DetectNetwork,
Action: "detect_only",
Metadata: map[string]string{
"reason": reason,
"fallback_strategy": fm.strategy,
},
Timestamp: time.Now(),
})
}
}
// Strategy returns the configured fallback strategy.
func (fm *FallbackManager) Strategy() string {
return fm.strategy
}