mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-05-03 08:12:37 +02:00
Release prep: 54 engines, self-hosted signatures, i18n, dashboard updates
This commit is contained in:
parent
694e32be26
commit
41cbfd6e0a
178 changed files with 36008 additions and 399 deletions
165
internal/application/resilience/behavioral.go
Normal file
165
internal/application/resilience/behavioral.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BehaviorProfile captures the runtime behavior of a component at one
// sampling instant. Values are filled by collectProfile from runtime.MemStats
// and runtime.NumGoroutine.
type BehaviorProfile struct {
	Goroutines   int     `json:"goroutines"`     // live goroutine count
	HeapAllocMB  float64 `json:"heap_alloc_mb"`  // MemStats.HeapAlloc in MiB
	HeapObjectsK float64 `json:"heap_objects_k"` // MemStats.HeapObjects in thousands
	GCPauseMs    float64 `json:"gc_pause_ms"`    // most recent GC pause in milliseconds
	NumGC        uint32  `json:"num_gc"`         // cumulative completed GC cycles
	// FileDescriptors and CustomMetrics are not populated by collectProfile
	// in this file; they exist for external producers of profiles.
	FileDescriptors int                `json:"file_descriptors,omitempty"`
	CustomMetrics   map[string]float64 `json:"custom_metrics,omitempty"`
}
|
||||
|
||||
// BehavioralAlert is emitted when a behavioral anomaly is detected, i.e. when
// a profiled metric deviates from its learned baseline beyond the Z-score
// threshold (see detectAnomalies).
type BehavioralAlert struct {
	Component   string    `json:"component"`    // component the metric belongs to
	AnomalyType string    `json:"anomaly_type"` // goroutine_leak, memory_leak, gc_pressure, etc.
	Metric      string    `json:"metric"`       // metric name, e.g. "heap_alloc_mb"
	Current     float64   `json:"current"`      // observed value
	Baseline    float64   `json:"baseline"`     // baseline mean the value was compared to
	ZScore      float64   `json:"z_score"`      // standard deviations from the baseline mean
	Severity    string    `json:"severity"`     // "WARNING" or "CRITICAL" in this file
	Timestamp   time.Time `json:"timestamp"`
}
|
||||
|
||||
// BehavioralAnalyzer provides Go-side runtime behavioral analysis.
// It profiles the current process and compares against learned baselines.
// On Linux, eBPF hooks (immune/resilience_hooks.c) extend this to kernel level.
type BehavioralAnalyzer struct {
	mu        sync.RWMutex         // NOTE(review): declared but not used by any method in this file — confirm intent
	metricsDB *MetricsDB           // time-series store for baselines
	alertBus  chan BehavioralAlert // buffered; sends are non-blocking (drops when full)
	interval  time.Duration        // profiling cadence used by Start
	component string               // self component name
	logger    *slog.Logger
}
|
||||
|
||||
// NewBehavioralAnalyzer creates a new behavioral analyzer.
|
||||
func NewBehavioralAnalyzer(component string, alertBufSize int) *BehavioralAnalyzer {
|
||||
if alertBufSize <= 0 {
|
||||
alertBufSize = 50
|
||||
}
|
||||
return &BehavioralAnalyzer{
|
||||
metricsDB: NewMetricsDB(DefaultMetricsWindow, DefaultMetricsMaxSize),
|
||||
alertBus: make(chan BehavioralAlert, alertBufSize),
|
||||
interval: 1 * time.Minute,
|
||||
component: component,
|
||||
logger: slog.Default().With("component", "sarl-behavioral"),
|
||||
}
|
||||
}
|
||||
|
||||
// AlertBus returns the receive-only channel for consuming behavioral alerts.
// Alerts may be dropped by the analyzer when the buffer is full, so consumers
// should drain promptly.
func (ba *BehavioralAnalyzer) AlertBus() <-chan BehavioralAlert {
	return ba.alertBus
}
|
||||
|
||||
// Start begins continuous behavioral monitoring. Blocks until ctx cancelled.
|
||||
func (ba *BehavioralAnalyzer) Start(ctx context.Context) {
|
||||
ba.logger.Info("behavioral analyzer started", "interval", ba.interval)
|
||||
|
||||
ticker := time.NewTicker(ba.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ba.logger.Info("behavioral analyzer stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
ba.collectAndAnalyze()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectAndAnalyze profiles the runtime, records the sample into the
// time-series store, and checks the fresh sample for anomalies.
func (ba *BehavioralAnalyzer) collectAndAnalyze() {
	profile := ba.collectProfile()
	// Store first so the sample contributes to future baselines, then compare
	// against the (now-updated) baseline.
	ba.storeMetrics(profile)
	ba.detectAnomalies(profile)
}
|
||||
|
||||
// collectProfile gathers current Go runtime stats into a BehaviorProfile.
// Note that runtime.ReadMemStats stops the world briefly.
func (ba *BehavioralAnalyzer) collectProfile() BehaviorProfile {
	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)

	return BehaviorProfile{
		Goroutines:   runtime.NumGoroutine(),
		HeapAllocMB:  float64(mem.HeapAlloc) / (1024 * 1024),
		HeapObjectsK: float64(mem.HeapObjects) / 1000,
		// PauseNs is a 256-entry circular buffer of recent GC pause times;
		// the most recent pause is at index (NumGC+255)%256.
		GCPauseMs: float64(mem.PauseNs[(mem.NumGC+255)%256]) / 1e6,
		NumGC:     mem.NumGC,
	}
}
|
||||
|
||||
// storeMetrics records profile data in the time-series DB.
|
||||
func (ba *BehavioralAnalyzer) storeMetrics(p BehaviorProfile) {
|
||||
ba.metricsDB.AddDataPoint(ba.component, "goroutines", float64(p.Goroutines))
|
||||
ba.metricsDB.AddDataPoint(ba.component, "heap_alloc_mb", p.HeapAllocMB)
|
||||
ba.metricsDB.AddDataPoint(ba.component, "heap_objects_k", p.HeapObjectsK)
|
||||
ba.metricsDB.AddDataPoint(ba.component, "gc_pause_ms", p.GCPauseMs)
|
||||
}
|
||||
|
||||
// detectAnomalies checks each metric against its baseline via Z-score.
|
||||
func (ba *BehavioralAnalyzer) detectAnomalies(p BehaviorProfile) {
|
||||
checks := []struct {
|
||||
metric string
|
||||
value float64
|
||||
anomalyType string
|
||||
severity string
|
||||
}{
|
||||
{"goroutines", float64(p.Goroutines), "goroutine_leak", "WARNING"},
|
||||
{"heap_alloc_mb", p.HeapAllocMB, "memory_leak", "CRITICAL"},
|
||||
{"heap_objects_k", p.HeapObjectsK, "object_leak", "WARNING"},
|
||||
{"gc_pause_ms", p.GCPauseMs, "gc_pressure", "WARNING"},
|
||||
}
|
||||
|
||||
for _, c := range checks {
|
||||
baseline := ba.metricsDB.GetBaseline(ba.component, c.metric, DefaultMetricsWindow)
|
||||
if !IsAnomaly(c.value, baseline, AnomalyZScoreThreshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
zscore := CalculateZScore(c.value, baseline)
|
||||
alert := BehavioralAlert{
|
||||
Component: ba.component,
|
||||
AnomalyType: c.anomalyType,
|
||||
Metric: c.metric,
|
||||
Current: c.value,
|
||||
Baseline: baseline.Mean,
|
||||
ZScore: zscore,
|
||||
Severity: c.severity,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
select {
|
||||
case ba.alertBus <- alert:
|
||||
ba.logger.Warn("behavioral anomaly detected",
|
||||
"type", c.anomalyType,
|
||||
"metric", c.metric,
|
||||
"z_score", zscore,
|
||||
)
|
||||
default:
|
||||
ba.logger.Error("behavioral alert bus full")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InjectMetric allows manually injecting a metric sample for testing,
// recorded under this analyzer's own component name.
func (ba *BehavioralAnalyzer) InjectMetric(metric string, value float64) {
	ba.metricsDB.AddDataPoint(ba.component, metric, value)
}
|
||||
|
||||
// CurrentProfile returns a freshly collected snapshot of the current runtime
// profile. It does not store the sample or run anomaly detection.
func (ba *BehavioralAnalyzer) CurrentProfile() BehaviorProfile {
	return ba.collectProfile()
}
|
||||
206
internal/application/resilience/behavioral_test.go
Normal file
206
internal/application/resilience/behavioral_test.go
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IM-01: Goroutine leak detection — a value far above a flat baseline must
// produce a goroutine_leak alert with a Z-score above the threshold.
func TestBehavioral_IM01_GoroutineLeak(t *testing.T) {
	ba := NewBehavioralAnalyzer("soc-ingest", 10)

	// Build baseline of 10 goroutines.
	for i := 0; i < 50; i++ {
		ba.InjectMetric("goroutines", 10)
	}

	// Spike to 1000 goroutines — should trigger anomaly.
	ba.metricsDB.AddDataPoint("soc-ingest", "goroutines", 1000)
	profile := BehaviorProfile{Goroutines: 1000}
	ba.detectAnomalies(profile)

	// detectAnomalies sends synchronously in this goroutine, so a
	// non-blocking receive is sufficient here.
	select {
	case alert := <-ba.alertBus:
		if alert.AnomalyType != "goroutine_leak" {
			t.Errorf("expected goroutine_leak, got %s", alert.AnomalyType)
		}
		if alert.ZScore <= 3 {
			t.Errorf("expected Z > 3, got %f", alert.ZScore)
		}
	default:
		t.Error("expected goroutine leak alert")
	}
}
|
||||
|
||||
// IM-02: Memory leak detection — a heap spike must produce a memory_leak
// alert and that alert must carry CRITICAL severity.
func TestBehavioral_IM02_MemoryLeak(t *testing.T) {
	ba := NewBehavioralAnalyzer("soc-correlate", 10)

	// Baseline: 50 MB.
	for i := 0; i < 50; i++ {
		ba.InjectMetric("heap_alloc_mb", 50)
	}

	// Spike to 500 MB.
	ba.metricsDB.AddDataPoint("soc-correlate", "heap_alloc_mb", 500)
	profile := BehaviorProfile{HeapAllocMB: 500}
	ba.detectAnomalies(profile)

	select {
	case alert := <-ba.alertBus:
		if alert.AnomalyType != "memory_leak" {
			t.Errorf("expected memory_leak, got %s", alert.AnomalyType)
		}
		if alert.Severity != "CRITICAL" {
			t.Errorf("expected CRITICAL severity, got %s", alert.Severity)
		}
	default:
		t.Error("expected memory leak alert")
	}
}
|
||||
|
||||
// IM-03: GC pressure detection — a GC pause spike above the flat baseline
// must produce a gc_pressure alert.
func TestBehavioral_IM03_GCPressure(t *testing.T) {
	ba := NewBehavioralAnalyzer("soc-respond", 10)

	// Baseline: 1ms GC pause.
	for i := 0; i < 50; i++ {
		ba.InjectMetric("gc_pause_ms", 1)
	}

	// Spike to 100ms.
	ba.metricsDB.AddDataPoint("soc-respond", "gc_pause_ms", 100)
	profile := BehaviorProfile{GCPauseMs: 100}
	ba.detectAnomalies(profile)

	select {
	case alert := <-ba.alertBus:
		if alert.AnomalyType != "gc_pressure" {
			t.Errorf("expected gc_pressure, got %s", alert.AnomalyType)
		}
	default:
		t.Error("expected gc_pressure alert")
	}
}
|
||||
|
||||
// IM-04: Object leak detection — a heap-object-count spike must produce an
// object_leak alert.
func TestBehavioral_IM04_ObjectLeak(t *testing.T) {
	ba := NewBehavioralAnalyzer("shield", 10)

	// Baseline: 100k heap objects.
	for i := 0; i < 50; i++ {
		ba.InjectMetric("heap_objects_k", 100)
	}

	// Spike to 5000k objects.
	ba.metricsDB.AddDataPoint("shield", "heap_objects_k", 5000)
	profile := BehaviorProfile{HeapObjectsK: 5000}
	ba.detectAnomalies(profile)

	select {
	case alert := <-ba.alertBus:
		if alert.AnomalyType != "object_leak" {
			t.Errorf("expected object_leak, got %s", alert.AnomalyType)
		}
	default:
		t.Error("expected object leak alert")
	}
}
|
||||
|
||||
// IM-05: Normal behavior — a profile matching the learned baselines exactly
// must produce no alerts on any metric.
func TestBehavioral_IM05_NormalBehavior(t *testing.T) {
	ba := NewBehavioralAnalyzer("sidecar", 10)

	// Flat baselines for all four tracked metrics.
	for i := 0; i < 50; i++ {
		ba.InjectMetric("goroutines", 10)
		ba.InjectMetric("heap_alloc_mb", 50)
		ba.InjectMetric("heap_objects_k", 100)
		ba.InjectMetric("gc_pause_ms", 1)
	}

	profile := BehaviorProfile{
		Goroutines:   10,
		HeapAllocMB:  50,
		HeapObjectsK: 100,
		GCPauseMs:    1,
	}
	ba.detectAnomalies(profile)

	select {
	case alert := <-ba.alertBus:
		t.Errorf("expected no alerts for normal behavior, got %+v", alert)
	default:
		// Good — no alerts.
	}
}
|
||||
|
||||
// IM-06: Start/Stop lifecycle — Start must return promptly once its context
// is cancelled (guards against a leaked monitoring goroutine).
func TestBehavioral_IM06_StartStop(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 10)
	// Shorten the interval so at least one tick fires during the test.
	ba.interval = 50 * time.Millisecond

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})

	go func() {
		ba.Start(ctx)
		close(done)
	}()

	time.Sleep(100 * time.Millisecond)
	cancel()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("Start() did not return after context cancellation")
	}
}
|
||||
|
||||
// IM-07: CurrentProfile returns valid data — a live process always has at
// least one goroutine and non-zero heap usage.
func TestBehavioral_IM07_CurrentProfile(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 10)
	profile := ba.CurrentProfile()

	if profile.Goroutines <= 0 {
		t.Error("expected positive goroutine count")
	}
	if profile.HeapAllocMB <= 0 {
		t.Error("expected positive heap alloc")
	}
}
|
||||
|
||||
// IM-08: Alert bus overflow — with the bus pre-filled, detectAnomalies must
// take its non-blocking default branch rather than block or panic.
func TestBehavioral_IM08_AlertBusOverflow(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 2)

	// Fill bus to capacity.
	ba.alertBus <- BehavioralAlert{AnomalyType: "fill1"}
	ba.alertBus <- BehavioralAlert{AnomalyType: "fill2"}

	// Build baseline.
	for i := 0; i < 50; i++ {
		ba.InjectMetric("goroutines", 10)
	}

	// This should not panic (or deadlock) despite the full bus.
	ba.metricsDB.AddDataPoint("test", "goroutines", 10000)
	ba.detectAnomalies(BehaviorProfile{Goroutines: 10000})
}
|
||||
|
||||
// Test collectAndAnalyze runs the full profile→store→detect pipeline against
// the real runtime without error.
func TestBehavioral_CollectAndAnalyze(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 10)
	// Should not panic.
	ba.collectAndAnalyze()
}
|
||||
|
||||
// Test InjectMetric stores data under the analyzer's own component name and
// the given metric key.
func TestBehavioral_InjectMetric(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 10)
	ba.InjectMetric("custom", 42.0)

	recent := ba.metricsDB.GetRecent("test", "custom", 1)
	if len(recent) != 1 || recent[0].Value != 42.0 {
		t.Errorf("expected 42.0, got %v", recent)
	}
}
|
||||
524
internal/application/resilience/healing_engine.go
Normal file
524
internal/application/resilience/healing_engine.go
Normal file
|
|
@ -0,0 +1,524 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealingState represents the FSM state of a healing operation.
// Transitions (driven by initiateHealing):
// IDLE → DIAGNOSING → HEALING → VERIFYING → COMPLETED | FAILED.
type HealingState string

const (
	HealingIdle       HealingState = "IDLE"        // operation created, not started
	HealingDiagnosing HealingState = "DIAGNOSING"  // root cause analysis in progress
	HealingActive     HealingState = "HEALING"     // strategy actions executing
	HealingVerifying  HealingState = "VERIFYING"   // post-healing health check
	HealingCompleted  HealingState = "COMPLETED"   // healing and verification succeeded
	HealingFailed     HealingState = "FAILED"      // an action or verification failed
)
|
||||
|
||||
// HealingResult summarizes a completed healing operation.
type HealingResult string

const (
	ResultSuccess HealingResult = "SUCCESS"
	ResultFailed  HealingResult = "FAILED"
	// ResultSkipped is declared for completeness; nothing in this file
	// assigns it.
	ResultSkipped HealingResult = "SKIPPED"
)
|
||||
|
||||
// ActionType defines the kinds of healing actions. The semantics of each
// action are supplied by the ActionExecutorFunc implementation, not here.
type ActionType string

const (
	// Component lifecycle.
	ActionGracefulStop   ActionType = "graceful_stop"
	ActionClearTempFiles ActionType = "clear_temp_files"
	ActionStartComponent ActionType = "start_component"
	ActionVerifyHealth   ActionType = "verify_health"
	ActionNotifySOC      ActionType = "notify_soc"

	// Configuration management.
	ActionFreezeConfig   ActionType = "freeze_config"
	ActionRollbackConfig ActionType = "rollback_config"
	ActionVerifyConfig   ActionType = "verify_config"

	// Data/storage recovery.
	ActionSwitchReadOnly  ActionType = "switch_to_readonly"
	ActionBackupDB        ActionType = "backup_db"
	ActionRestoreSnapshot ActionType = "restore_snapshot"
	ActionVerifyIntegrity ActionType = "verify_integrity"
	ActionResumeWrites    ActionType = "resume_writes"

	// Detection engine management.
	ActionDisableRules ActionType = "disable_rules"
	ActionRevertRules  ActionType = "revert_rules"
	ActionReloadEngine ActionType = "reload_engine"

	// Network and escalation.
	ActionIsolateNetwork  ActionType = "isolate_network"
	ActionRegenCerts      ActionType = "regenerate_certs"
	ActionRestoreNetwork  ActionType = "restore_network"
	ActionNotifyArchitect ActionType = "notify_architect"
	ActionEnterSafeMode   ActionType = "enter_safe_mode"
)
|
||||
|
||||
// Action is a single step in a healing strategy.
type Action struct {
	Type   ActionType             `json:"type"`
	Params map[string]interface{} `json:"params,omitempty"` // executor-specific parameters
	// Timeout bounds the executor call; values <= 0 mean no per-action
	// timeout (see executeActions).
	Timeout time.Duration `json:"timeout"`
	// OnError selects the failure policy in executeActions:
	// "continue", "rollback", or anything else (treated as "abort").
	OnError string `json:"on_error"` // "continue", "abort", "rollback"
}
|
||||
|
||||
// TriggerCondition defines when a healing strategy activates.
type TriggerCondition struct {
	Metrics  []string          `json:"metrics,omitempty"`  // matched against HealthAlert.Metric
	Statuses []ComponentStatus `json:"statuses,omitempty"` // matched against alert severity/action
	// NOTE(review): ConsecutiveFailures and WithinWindow are not evaluated by
	// matchesTrigger in this file — confirm whether enforcement happens
	// elsewhere or is still TODO.
	ConsecutiveFailures int           `json:"consecutive_failures"`
	WithinWindow        time.Duration `json:"within_window"`
}
|
||||
|
||||
// RollbackPlan defines what happens if healing fails.
type RollbackPlan struct {
	// OnFailure names the escalation policy; it is carried as data and not
	// interpreted by the engine in this file.
	OnFailure string `json:"on_failure"` // "escalate", "enter_safe_mode", "maintain_isolation"
	// Actions are executed best-effort by executeRollback after a failed
	// healing attempt.
	Actions []Action `json:"actions,omitempty"`
}
|
||||
|
||||
// HealingStrategy is a complete self-healing plan: when to fire (Trigger),
// what to do (Actions), and what to do if that fails (Rollback).
type HealingStrategy struct {
	ID      string           `json:"id"`   // unique key, also used for cooldown tracking
	Name    string           `json:"name"` // human-readable label
	Trigger TriggerCondition `json:"trigger"`
	Actions []Action         `json:"actions"`
	Rollback RollbackPlan    `json:"rollback"`
	// NOTE(review): MaxAttempts is not read anywhere in this file — confirm
	// whether retry limiting is implemented elsewhere.
	MaxAttempts int `json:"max_attempts"`
	// Cooldown is the minimum gap between runs of this strategy (enforced
	// via isInCooldown/setCooldown).
	Cooldown time.Duration `json:"cooldown"`
}
|
||||
|
||||
// Diagnosis is the result of root cause analysis (see diagnose).
type Diagnosis struct {
	Component  string  `json:"component"`
	Metric     string  `json:"metric"`
	RootCause  string  `json:"root_cause"`
	Confidence float64 `json:"confidence"` // heuristic confidence in [0.5, 0.95] as produced by diagnose
	SuggestedFix string `json:"suggested_fix"`
	// RelatedAlerts is not populated by diagnose in this file.
	RelatedAlerts []HealthAlert `json:"related_alerts,omitempty"`
}
|
||||
|
||||
// HealingOperation tracks a single healing attempt through the FSM.
type HealingOperation struct {
	ID         string       `json:"id"`          // "heal-<counter>", assigned by createOperation
	StrategyID string       `json:"strategy_id"`
	Component  string       `json:"component"`
	State      HealingState `json:"state"`
	Diagnosis  *Diagnosis   `json:"diagnosis,omitempty"`
	ActionsRun []ActionLog  `json:"actions_run"` // one entry per executed action, in order
	Result     HealingResult `json:"result"`
	StartedAt  time.Time    `json:"started_at"`
	// NOTE(review): `omitempty` has no effect on struct-typed fields like
	// time.Time — a zero CompletedAt still serializes.
	CompletedAt   time.Time `json:"completed_at,omitempty"`
	Error         string    `json:"error,omitempty"`
	// AttemptNumber is not set anywhere in this file.
	AttemptNumber int `json:"attempt_number"`
}
|
||||
|
||||
// ActionLog records the execution of a single action.
type ActionLog struct {
	Action    ActionType    `json:"action"`
	StartedAt time.Time     `json:"started_at"`
	Duration  time.Duration `json:"duration"`
	Success   bool          `json:"success"`
	Error     string        `json:"error,omitempty"` // executor error text when Success is false
}
|
||||
|
||||
// ActionExecutorFunc is the callback that actually runs an action.
// Implementations handle the real system operations (restart, rollback, etc.)
// and should respect ctx cancellation, since per-action timeouts are applied
// through the context.
type ActionExecutorFunc func(ctx context.Context, action Action, component string) error
|
||||
|
||||
// HealingEngine is the L2 Self-Healing orchestrator: it consumes health
// alerts, matches them to registered strategies, and runs the healing FSM.
type HealingEngine struct {
	mu         sync.RWMutex
	strategies []HealingStrategy
	cooldowns  map[string]time.Time // strategyID → earliest next run
	// operations grows for the lifetime of the engine; see createOperation.
	operations []*HealingOperation
	opCounter  int64 // monotonically increasing, used to mint operation IDs
	executor   ActionExecutorFunc
	alertBus   <-chan HealthAlert
	escalateFn func(HealthAlert) // called on unrecoverable failure; may be nil
	logger     *slog.Logger
}
|
||||
|
||||
// NewHealingEngine creates a new self-healing engine.
|
||||
func NewHealingEngine(
|
||||
alertBus <-chan HealthAlert,
|
||||
executor ActionExecutorFunc,
|
||||
escalateFn func(HealthAlert),
|
||||
) *HealingEngine {
|
||||
return &HealingEngine{
|
||||
cooldowns: make(map[string]time.Time),
|
||||
operations: make([]*HealingOperation, 0),
|
||||
executor: executor,
|
||||
alertBus: alertBus,
|
||||
escalateFn: escalateFn,
|
||||
logger: slog.Default().With("component", "sarl-healing-engine"),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterStrategy adds a healing strategy. Strategies are matched in
// registration order by findStrategy (first match wins).
func (he *HealingEngine) RegisterStrategy(s HealingStrategy) {
	he.mu.Lock()
	defer he.mu.Unlock()
	he.strategies = append(he.strategies, s)
	he.logger.Info("strategy registered", "id", s.ID, "name", s.Name)
}
|
||||
|
||||
// Start begins listening for alerts and initiating healing. Blocks until ctx is cancelled.
|
||||
func (he *HealingEngine) Start(ctx context.Context) {
|
||||
he.logger.Info("healing engine started", "strategies", len(he.strategies))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
he.logger.Info("healing engine stopped")
|
||||
return
|
||||
case alert, ok := <-he.alertBus:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if alert.Severity == SeverityCritical || alert.Severity == SeverityWarning {
|
||||
he.initiateHealing(ctx, alert)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initiateHealing runs the healing pipeline for an alert:
// strategy lookup → cooldown gate → diagnose → execute → verify → complete or
// rollback+escalate. The strategy's cooldown is set regardless of outcome.
//
// NOTE(review): op fields (State, Diagnosis, Result, Error, CompletedAt, and
// ActionsRun via executeActions) are mutated here without holding he.mu,
// while GetOperation/RecentOperations read the same *HealingOperation under
// RLock — confirm single-writer assumptions or add locking around these
// writes.
func (he *HealingEngine) initiateHealing(ctx context.Context, alert HealthAlert) {
	strategy := he.findStrategy(alert)
	if strategy == nil {
		he.logger.Info("no matching strategy for alert",
			"component", alert.Component,
			"metric", alert.Metric,
		)
		return
	}

	// Cooldown gate prevents healing storms from repeated alerts.
	if he.isInCooldown(strategy.ID) {
		he.logger.Info("strategy in cooldown",
			"strategy", strategy.ID,
			"component", alert.Component,
		)
		return
	}

	op := he.createOperation(strategy, alert.Component)

	he.logger.Info("healing initiated",
		"op_id", op.ID,
		"strategy", strategy.ID,
		"component", alert.Component,
	)

	// Phase 1: Diagnose.
	he.transitionOp(op, HealingDiagnosing)
	diagnosis := he.diagnose(alert)
	op.Diagnosis = &diagnosis

	// Phase 2: Execute healing actions.
	he.transitionOp(op, HealingActive)
	execErr := he.executeActions(ctx, strategy, op)

	// Phase 3: Verify recovery (only if all actions succeeded).
	if execErr == nil {
		he.transitionOp(op, HealingVerifying)
		verifyErr := he.verifyRecovery(ctx, strategy, op.Component)
		if verifyErr != nil {
			execErr = verifyErr
		}
	}

	// Phase 4: Complete or fail.
	if execErr == nil {
		he.transitionOp(op, HealingCompleted)
		op.Result = ResultSuccess
		he.logger.Info("healing completed successfully",
			"op_id", op.ID,
			"component", op.Component,
			"duration", time.Since(op.StartedAt),
		)
	} else {
		he.transitionOp(op, HealingFailed)
		op.Result = ResultFailed
		op.Error = execErr.Error()
		he.logger.Error("healing failed",
			"op_id", op.ID,
			"component", op.Component,
			"error", execErr,
		)

		// Execute rollback.
		he.executeRollback(ctx, strategy, op)

		// Escalate.
		if he.escalateFn != nil {
			he.escalateFn(alert)
		}
	}

	op.CompletedAt = time.Now()
	he.setCooldown(strategy.ID, strategy.Cooldown)
}
|
||||
|
||||
// findStrategy returns the first matching strategy for an alert.
|
||||
func (he *HealingEngine) findStrategy(alert HealthAlert) *HealingStrategy {
|
||||
he.mu.RLock()
|
||||
defer he.mu.RUnlock()
|
||||
|
||||
for i := range he.strategies {
|
||||
s := &he.strategies[i]
|
||||
if he.matchesTrigger(s.Trigger, alert) {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchesTrigger checks if an alert matches a strategy's trigger condition.
|
||||
func (he *HealingEngine) matchesTrigger(trigger TriggerCondition, alert HealthAlert) bool {
|
||||
// Match by metric name.
|
||||
for _, m := range trigger.Metrics {
|
||||
if m == alert.Metric {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Match by component status.
|
||||
for _, s := range trigger.Statuses {
|
||||
switch s {
|
||||
case StatusCritical:
|
||||
if alert.Severity == SeverityCritical {
|
||||
return true
|
||||
}
|
||||
case StatusOffline:
|
||||
if alert.Severity == SeverityCritical && alert.SuggestedAction == "restart" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isInCooldown checks if a strategy is still in its cooldown period.
|
||||
func (he *HealingEngine) isInCooldown(strategyID string) bool {
|
||||
he.mu.RLock()
|
||||
defer he.mu.RUnlock()
|
||||
|
||||
earliest, ok := he.cooldowns[strategyID]
|
||||
return ok && time.Now().Before(earliest)
|
||||
}
|
||||
|
||||
// setCooldown marks a strategy as cooling down: it may not run again until
// now + duration (see isInCooldown).
func (he *HealingEngine) setCooldown(strategyID string, duration time.Duration) {
	he.mu.Lock()
	defer he.mu.Unlock()
	he.cooldowns[strategyID] = time.Now().Add(duration)
}
|
||||
|
||||
// createOperation creates and records a new healing operation with a unique
// sequential ID, in the IDLE state.
//
// NOTE(review): he.operations grows without bound for the lifetime of the
// engine — consider a retention cap if the process is long-lived.
func (he *HealingEngine) createOperation(strategy *HealingStrategy, component string) *HealingOperation {
	he.mu.Lock()
	defer he.mu.Unlock()

	he.opCounter++
	op := &HealingOperation{
		ID:         fmt.Sprintf("heal-%d", he.opCounter),
		StrategyID: strategy.ID,
		Component:  component,
		State:      HealingIdle,
		StartedAt:  time.Now(),
		ActionsRun: make([]ActionLog, 0),
	}
	he.operations = append(he.operations, op)
	return op
}
|
||||
|
||||
// transitionOp moves an operation to a new state, logging the transition.
// No validity checking of transitions is performed here.
//
// NOTE(review): op.State is written without holding he.mu — confirm this is
// only ever called from the single engine goroutine.
func (he *HealingEngine) transitionOp(op *HealingOperation, newState HealingState) {
	he.logger.Debug("healing state transition",
		"op_id", op.ID,
		"from", op.State,
		"to", newState,
	)
	op.State = newState
}
|
||||
|
||||
// diagnose performs root cause analysis for an alert.
|
||||
func (he *HealingEngine) diagnose(alert HealthAlert) Diagnosis {
|
||||
rootCause := "unknown"
|
||||
confidence := 0.5
|
||||
suggestedFix := "restart component"
|
||||
|
||||
switch {
|
||||
case alert.Metric == "memory" && alert.Current > 90:
|
||||
rootCause = "memory_exhaustion"
|
||||
confidence = 0.9
|
||||
suggestedFix = "restart with increased limits"
|
||||
case alert.Metric == "cpu" && alert.Current > 90:
|
||||
rootCause = "cpu_saturation"
|
||||
confidence = 0.8
|
||||
suggestedFix = "check for runaway goroutines"
|
||||
case alert.Metric == "error_rate":
|
||||
rootCause = "elevated_error_rate"
|
||||
confidence = 0.7
|
||||
suggestedFix = "check dependencies and config"
|
||||
case alert.Metric == "latency_p99":
|
||||
rootCause = "latency_degradation"
|
||||
confidence = 0.6
|
||||
suggestedFix = "check database and network"
|
||||
case alert.Metric == "quorum":
|
||||
rootCause = "quorum_loss"
|
||||
confidence = 0.95
|
||||
suggestedFix = "activate safe mode"
|
||||
default:
|
||||
rootCause = fmt.Sprintf("threshold_breach_%s", alert.Metric)
|
||||
confidence = 0.5
|
||||
suggestedFix = "investigate manually"
|
||||
}
|
||||
|
||||
return Diagnosis{
|
||||
Component: alert.Component,
|
||||
Metric: alert.Metric,
|
||||
RootCause: rootCause,
|
||||
Confidence: confidence,
|
||||
SuggestedFix: suggestedFix,
|
||||
}
|
||||
}
|
||||
|
||||
// executeActions runs each action in sequence.
|
||||
func (he *HealingEngine) executeActions(ctx context.Context, strategy *HealingStrategy, op *HealingOperation) error {
|
||||
for _, action := range strategy.Actions {
|
||||
actionCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if action.Timeout > 0 {
|
||||
actionCtx, cancel = context.WithTimeout(ctx, action.Timeout)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err := he.executor(actionCtx, action, op.Component)
|
||||
duration := time.Since(start)
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
logEntry := ActionLog{
|
||||
Action: action.Type,
|
||||
StartedAt: start,
|
||||
Duration: duration,
|
||||
Success: err == nil,
|
||||
}
|
||||
if err != nil {
|
||||
logEntry.Error = err.Error()
|
||||
}
|
||||
op.ActionsRun = append(op.ActionsRun, logEntry)
|
||||
|
||||
if err != nil {
|
||||
switch action.OnError {
|
||||
case "continue":
|
||||
he.logger.Warn("action failed, continuing",
|
||||
"action", action.Type,
|
||||
"error", err,
|
||||
)
|
||||
case "rollback":
|
||||
return fmt.Errorf("action %s failed (rollback): %w", action.Type, err)
|
||||
default: // "abort"
|
||||
return fmt.Errorf("action %s failed: %w", action.Type, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyRecovery checks if the component is healthy after healing by running
// a synthetic verify_health action through the executor with a 30s timeout.
//
// NOTE(review): the strategy parameter is currently unused — the comment
// below suggests it was intended to skip verification when the strategy
// already contains a verify_health action; confirm.
func (he *HealingEngine) verifyRecovery(ctx context.Context, strategy *HealingStrategy, component string) error {
	// Execute a verify_health action if not already in the strategy.
	verifyAction := Action{
		Type:    ActionVerifyHealth,
		Timeout: 30 * time.Second,
	}
	return he.executor(ctx, verifyAction, component)
}
|
||||
|
||||
// executeRollback runs the rollback plan for a failed healing.
|
||||
func (he *HealingEngine) executeRollback(ctx context.Context, strategy *HealingStrategy, op *HealingOperation) {
|
||||
if len(strategy.Rollback.Actions) == 0 {
|
||||
he.logger.Info("no rollback actions defined",
|
||||
"strategy", strategy.ID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
he.logger.Warn("executing rollback",
|
||||
"strategy", strategy.ID,
|
||||
"component", op.Component,
|
||||
)
|
||||
|
||||
for _, action := range strategy.Rollback.Actions {
|
||||
if err := he.executor(ctx, action, op.Component); err != nil {
|
||||
he.logger.Error("rollback action failed",
|
||||
"action", action.Type,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetOperation returns a healing operation by ID.
|
||||
// Returns a deep copy to prevent data races with the healing goroutine.
|
||||
func (he *HealingEngine) GetOperation(id string) (*HealingOperation, bool) {
|
||||
he.mu.RLock()
|
||||
defer he.mu.RUnlock()
|
||||
|
||||
for _, op := range he.operations {
|
||||
if op.ID == id {
|
||||
cp := *op
|
||||
cp.ActionsRun = make([]ActionLog, len(op.ActionsRun))
|
||||
copy(cp.ActionsRun, op.ActionsRun)
|
||||
if op.Diagnosis != nil {
|
||||
diag := *op.Diagnosis
|
||||
cp.Diagnosis = &diag
|
||||
}
|
||||
return &cp, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RecentOperations returns the last N operations.
|
||||
// Returns deep copies to prevent data races with the healing goroutine.
|
||||
func (he *HealingEngine) RecentOperations(n int) []HealingOperation {
|
||||
he.mu.RLock()
|
||||
defer he.mu.RUnlock()
|
||||
|
||||
total := len(he.operations)
|
||||
if total == 0 {
|
||||
return nil
|
||||
}
|
||||
start := total - n
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
result := make([]HealingOperation, 0, n)
|
||||
for i := start; i < total; i++ {
|
||||
cp := *he.operations[i]
|
||||
cp.ActionsRun = make([]ActionLog, len(he.operations[i].ActionsRun))
|
||||
copy(cp.ActionsRun, he.operations[i].ActionsRun)
|
||||
if he.operations[i].Diagnosis != nil {
|
||||
diag := *he.operations[i].Diagnosis
|
||||
cp.Diagnosis = &diag
|
||||
}
|
||||
result = append(result, cp)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// StrategyCount returns the number of registered strategies.
func (he *HealingEngine) StrategyCount() int {
	he.mu.RLock()
	defer he.mu.RUnlock()
	return len(he.strategies)
}
|
||||
588
internal/application/resilience/healing_engine_test.go
Normal file
588
internal/application/resilience/healing_engine_test.go
Normal file
|
|
@ -0,0 +1,588 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Mock executor for tests ---
|
||||
|
||||
// mockExecutorLog is a test double for ActionExecutorFunc that records every
// action it receives and can be told to fail specific action types.
type mockExecutorLog struct {
	// NOTE(review): actions is appended to without synchronization in
	// execute; only safe while the engine runs a single goroutine and the
	// test does not read it concurrently.
	actions []ActionType
	fail    map[ActionType]bool // action types that should return an error
	count   atomic.Int64        // total invocations, safe to read concurrently
}
|
||||
|
||||
// newMockExecutor returns a mock executor with an initialized (empty)
// failure map so tests can set fail[type] directly.
func newMockExecutor() *mockExecutorLog {
	return &mockExecutorLog{
		fail: make(map[ActionType]bool),
	}
}
|
||||
|
||||
// execute implements ActionExecutorFunc: it counts and records the action,
// then fails iff the action type was marked in m.fail.
func (m *mockExecutorLog) execute(_ context.Context, action Action, _ string) error {
	m.count.Add(1)
	m.actions = append(m.actions, action.Type)
	if m.fail[action.Type] {
		return fmt.Errorf("action %s failed", action.Type)
	}
	return nil
}
|
||||
|
||||
// --- Healing Engine Tests ---
|
||||
|
||||
// HE-01: Component restart (success).
|
||||
func TestHealingEngine_HE01_RestartSuccess(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
escalated := false
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, func(_ HealthAlert) {
|
||||
escalated = true
|
||||
})
|
||||
he.RegisterStrategy(RestartComponentStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-ingest",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Run one healing cycle.
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected at least 1 operation")
|
||||
}
|
||||
if ops[0].Result != ResultSuccess {
|
||||
t.Errorf("expected SUCCESS, got %s (error: %s)", ops[0].Result, ops[0].Error)
|
||||
}
|
||||
if escalated {
|
||||
t.Error("should not have escalated on success")
|
||||
}
|
||||
}
|
||||
|
||||
// HE-02: Component restart (failure ×3 → escalate).
|
||||
func TestHealingEngine_HE02_RestartFailureEscalate(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
mock.fail[ActionStartComponent] = true // Start always fails.
|
||||
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
escalated := false
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, func(_ HealthAlert) {
|
||||
escalated = true
|
||||
})
|
||||
he.RegisterStrategy(RestartComponentStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-correlate",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
if !escalated {
|
||||
t.Error("expected escalation on failure")
|
||||
}
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected operation")
|
||||
}
|
||||
if ops[0].Result != ResultFailed {
|
||||
t.Errorf("expected FAILED, got %s", ops[0].Result)
|
||||
}
|
||||
}
|
||||
|
||||
// HE-03: Config rollback strategy matching.
|
||||
func TestHealingEngine_HE03_ConfigRollback(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RollbackConfigStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-ingest",
|
||||
Severity: SeverityWarning,
|
||||
Metric: "config_tampering",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected operation for config rollback")
|
||||
}
|
||||
if ops[0].StrategyID != "ROLLBACK_CONFIG" {
|
||||
t.Errorf("expected ROLLBACK_CONFIG, got %s", ops[0].StrategyID)
|
||||
}
|
||||
}
|
||||
|
||||
// HE-04: Database recovery.
|
||||
func TestHealingEngine_HE04_DatabaseRecovery(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RecoverDatabaseStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-correlate",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "database_corruption",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected DB recovery op")
|
||||
}
|
||||
if ops[0].StrategyID != "RECOVER_DATABASE" {
|
||||
t.Errorf("expected RECOVER_DATABASE, got %s", ops[0].StrategyID)
|
||||
}
|
||||
}
|
||||
|
||||
// HE-05: Rule poisoning defense.
|
||||
func TestHealingEngine_HE05_RulePoisoning(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RecoverRulesStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-correlate",
|
||||
Severity: SeverityWarning,
|
||||
Metric: "rule_execution_failure_rate",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected rule recovery op")
|
||||
}
|
||||
if ops[0].StrategyID != "RECOVER_RULES" {
|
||||
t.Errorf("expected RECOVER_RULES, got %s", ops[0].StrategyID)
|
||||
}
|
||||
}
|
||||
|
||||
// HE-06: Network isolation recovery.
|
||||
func TestHealingEngine_HE06_NetworkRecovery(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RecoverNetworkStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-respond",
|
||||
Severity: SeverityWarning,
|
||||
Metric: "network_partition",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected network recovery op")
|
||||
}
|
||||
if ops[0].StrategyID != "RECOVER_NETWORK" {
|
||||
t.Errorf("expected RECOVER_NETWORK, got %s", ops[0].StrategyID)
|
||||
}
|
||||
}
|
||||
|
||||
// HE-07: Cooldown enforcement.
|
||||
func TestHealingEngine_HE07_Cooldown(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RestartComponentStrategy())
|
||||
|
||||
// Set cooldown manually.
|
||||
he.setCooldown("RESTART_COMPONENT", 1*time.Hour)
|
||||
|
||||
if !he.isInCooldown("RESTART_COMPONENT") {
|
||||
t.Error("expected cooldown active")
|
||||
}
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-ingest",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) != 0 {
|
||||
t.Error("expected 0 operations during cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
// HE-08: Rollback on failure.
|
||||
func TestHealingEngine_HE08_Rollback(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
mock.fail[ActionStartComponent] = true
|
||||
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
he := NewHealingEngine(alertCh, mock.execute, func(_ HealthAlert) {})
|
||||
|
||||
strategy := RollbackConfigStrategy()
|
||||
he.RegisterStrategy(strategy)
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "soc-ingest",
|
||||
Severity: SeverityWarning,
|
||||
Metric: "config_tampering",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
// Rollback should have executed enter_safe_mode.
|
||||
foundSafeMode := false
|
||||
for _, a := range mock.actions {
|
||||
if a == ActionEnterSafeMode {
|
||||
foundSafeMode = true
|
||||
}
|
||||
}
|
||||
if !foundSafeMode {
|
||||
t.Errorf("expected safe mode in rollback, actions: %v", mock.actions)
|
||||
}
|
||||
}
|
||||
|
||||
// HE-09: State machine transitions.
|
||||
func TestHealingEngine_HE09_StateTransitions(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RestartComponentStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "comp",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected operation")
|
||||
}
|
||||
// Final state should be COMPLETED.
|
||||
if ops[0].State != HealingCompleted {
|
||||
t.Errorf("expected COMPLETED, got %s", ops[0].State)
|
||||
}
|
||||
}
|
||||
|
||||
// HE-10: Audit logging — all actions recorded.
|
||||
func TestHealingEngine_HE10_AuditLogging(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RestartComponentStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "comp",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected operation")
|
||||
}
|
||||
if len(ops[0].ActionsRun) == 0 {
|
||||
t.Error("expected action logs")
|
||||
}
|
||||
for _, al := range ops[0].ActionsRun {
|
||||
if al.StartedAt.IsZero() {
|
||||
t.Error("action log missing start time")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HE-11: Parallel healing — no race conditions.
|
||||
func TestHealingEngine_HE11_Parallel(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 100)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
for _, s := range DefaultStrategies() {
|
||||
he.RegisterStrategy(s)
|
||||
}
|
||||
|
||||
// Send many alerts concurrently.
|
||||
for i := 0; i < 10; i++ {
|
||||
alertCh <- HealthAlert{
|
||||
Component: fmt.Sprintf("comp-%d", i),
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(1 * time.Second)
|
||||
cancel()
|
||||
|
||||
// All 10 alerts processed (first gets an op, rest hit cooldown).
|
||||
ops := he.RecentOperations(100)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected at least 1 operation")
|
||||
}
|
||||
}
|
||||
|
||||
// HE-12: No matching strategy → no operation.
|
||||
func TestHealingEngine_HE12_NoStrategy(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
// No strategies registered.
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "comp",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "unknown_metric",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) != 0 {
|
||||
t.Errorf("expected 0 operations, got %d", len(ops))
|
||||
}
|
||||
}
|
||||
|
||||
// Test diagnosis (various root causes).
|
||||
func TestHealingEngine_Diagnosis(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
he := NewHealingEngine(nil, mock.execute, nil)
|
||||
|
||||
tests := []struct {
|
||||
metric string
|
||||
current float64
|
||||
wantCause string
|
||||
}{
|
||||
{"memory", 95, "memory_exhaustion"},
|
||||
{"cpu", 95, "cpu_saturation"},
|
||||
{"error_rate", 10, "elevated_error_rate"},
|
||||
{"latency_p99", 200, "latency_degradation"},
|
||||
{"quorum", 0.3, "quorum_loss"},
|
||||
{"custom", 100, "threshold_breach_custom"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
alert := HealthAlert{
|
||||
Component: "test",
|
||||
Metric: tt.metric,
|
||||
Current: tt.current,
|
||||
}
|
||||
d := he.diagnose(alert)
|
||||
if d.RootCause != tt.wantCause {
|
||||
t.Errorf("metric=%s: expected %s, got %s", tt.metric, tt.wantCause, d.RootCause)
|
||||
}
|
||||
if d.Confidence <= 0 || d.Confidence > 1 {
|
||||
t.Errorf("metric=%s: invalid confidence %f", tt.metric, d.Confidence)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test DefaultStrategies returns 5 strategies.
|
||||
func TestDefaultStrategies(t *testing.T) {
|
||||
strategies := DefaultStrategies()
|
||||
if len(strategies) != 5 {
|
||||
t.Errorf("expected 5 strategies, got %d", len(strategies))
|
||||
}
|
||||
|
||||
ids := map[string]bool{}
|
||||
for _, s := range strategies {
|
||||
if ids[s.ID] {
|
||||
t.Errorf("duplicate strategy ID: %s", s.ID)
|
||||
}
|
||||
ids[s.ID] = true
|
||||
if s.MaxAttempts <= 0 {
|
||||
t.Errorf("strategy %s: invalid max_attempts %d", s.ID, s.MaxAttempts)
|
||||
}
|
||||
if s.Cooldown <= 0 {
|
||||
t.Errorf("strategy %s: invalid cooldown %v", s.ID, s.Cooldown)
|
||||
}
|
||||
if len(s.Actions) == 0 {
|
||||
t.Errorf("strategy %s: no actions defined", s.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test StrategyCount.
|
||||
func TestHealingEngine_StrategyCount(t *testing.T) {
|
||||
he := NewHealingEngine(nil, nil, nil)
|
||||
if he.StrategyCount() != 0 {
|
||||
t.Error("expected 0")
|
||||
}
|
||||
for _, s := range DefaultStrategies() {
|
||||
he.RegisterStrategy(s)
|
||||
}
|
||||
if he.StrategyCount() != 5 {
|
||||
t.Errorf("expected 5, got %d", he.StrategyCount())
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetOperation.
|
||||
func TestHealingEngine_GetOperation(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RestartComponentStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "comp",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
op, ok := he.GetOperation("heal-1")
|
||||
if !ok {
|
||||
t.Fatal("expected operation heal-1")
|
||||
}
|
||||
if op.Component != "comp" {
|
||||
t.Errorf("expected comp, got %s", op.Component)
|
||||
}
|
||||
|
||||
_, ok = he.GetOperation("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected not found for nonexistent")
|
||||
}
|
||||
}
|
||||
|
||||
// Test action OnError=continue.
|
||||
func TestHealingEngine_ActionContinueOnError(t *testing.T) {
|
||||
mock := newMockExecutor()
|
||||
mock.fail[ActionGracefulStop] = true // First action fails but marked continue.
|
||||
|
||||
alertCh := make(chan HealthAlert, 10)
|
||||
he := NewHealingEngine(alertCh, mock.execute, nil)
|
||||
he.RegisterStrategy(RestartComponentStrategy())
|
||||
|
||||
alertCh <- HealthAlert{
|
||||
Component: "comp",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
SuggestedAction: "restart",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go he.Start(ctx)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
ops := he.RecentOperations(10)
|
||||
if len(ops) == 0 {
|
||||
t.Fatal("expected operation")
|
||||
}
|
||||
// Should still succeed because graceful_stop has OnError=continue.
|
||||
if ops[0].Result != ResultSuccess {
|
||||
t.Errorf("expected SUCCESS (continue on error), got %s", ops[0].Result)
|
||||
}
|
||||
}
|
||||
215
internal/application/resilience/healing_strategies.go
Normal file
215
internal/application/resilience/healing_strategies.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
package resilience
|
||||
|
||||
import "time"
|
||||
|
||||
// Built-in healing strategies per ТЗ §4.1.1.
|
||||
// These are registered at startup via HealingEngine.RegisterStrategy().
|
||||
|
||||
// DefaultStrategies returns the 5 built-in healing strategies.
|
||||
func DefaultStrategies() []HealingStrategy {
|
||||
return []HealingStrategy{
|
||||
RestartComponentStrategy(),
|
||||
RollbackConfigStrategy(),
|
||||
RecoverDatabaseStrategy(),
|
||||
RecoverRulesStrategy(),
|
||||
RecoverNetworkStrategy(),
|
||||
}
|
||||
}
|
||||
|
||||
// RestartComponentStrategy handles component crashes and offline states.
|
||||
// Trigger: component_offline OR component_critical, 2 consecutive failures within 5m.
|
||||
// Actions: graceful_stop → clear_temp → start → verify → notify.
|
||||
// Rollback: escalate to next strategy.
|
||||
func RestartComponentStrategy() HealingStrategy {
|
||||
return HealingStrategy{
|
||||
ID: "RESTART_COMPONENT",
|
||||
Name: "Component Restart",
|
||||
Trigger: TriggerCondition{
|
||||
Statuses: []ComponentStatus{StatusOffline, StatusCritical},
|
||||
ConsecutiveFailures: 2,
|
||||
WithinWindow: 5 * time.Minute,
|
||||
},
|
||||
Actions: []Action{
|
||||
{Type: ActionGracefulStop, Timeout: 10 * time.Second, OnError: "continue"},
|
||||
{Type: ActionClearTempFiles, Timeout: 5 * time.Second, OnError: "continue"},
|
||||
{Type: ActionStartComponent, Timeout: 30 * time.Second, OnError: "abort"},
|
||||
{Type: ActionVerifyHealth, Timeout: 60 * time.Second, OnError: "abort"},
|
||||
{Type: ActionNotifySOC, Timeout: 5 * time.Second, OnError: "continue",
|
||||
Params: map[string]interface{}{
|
||||
"severity": "INFO",
|
||||
"message": "Component restarted successfully",
|
||||
},
|
||||
},
|
||||
},
|
||||
Rollback: RollbackPlan{
|
||||
OnFailure: "escalate",
|
||||
Actions: []Action{
|
||||
{Type: ActionNotifyArchitect, Timeout: 5 * time.Second,
|
||||
Params: map[string]interface{}{
|
||||
"severity": "CRITICAL",
|
||||
"message": "Component restart failed after max attempts",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
MaxAttempts: 3,
|
||||
Cooldown: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// RollbackConfigStrategy handles config tampering or validation failures.
|
||||
// Trigger: config_tampering_detected OR config_validation_failed.
|
||||
// Actions: freeze → verify_backup → rollback → restart → verify → notify.
|
||||
func RollbackConfigStrategy() HealingStrategy {
|
||||
return HealingStrategy{
|
||||
ID: "ROLLBACK_CONFIG",
|
||||
Name: "Configuration Rollback",
|
||||
Trigger: TriggerCondition{
|
||||
Metrics: []string{"config_tampering", "config_validation"},
|
||||
},
|
||||
Actions: []Action{
|
||||
{Type: ActionFreezeConfig, Timeout: 5 * time.Second, OnError: "abort"},
|
||||
{Type: ActionRollbackConfig, Timeout: 15 * time.Second, OnError: "abort"},
|
||||
{Type: ActionStartComponent, Timeout: 30 * time.Second, OnError: "rollback"},
|
||||
{Type: ActionVerifyConfig, Timeout: 10 * time.Second, OnError: "abort"},
|
||||
{Type: ActionNotifyArchitect, Timeout: 5 * time.Second, OnError: "continue",
|
||||
Params: map[string]interface{}{
|
||||
"severity": "WARNING",
|
||||
"message": "Config rolled back due to tampering",
|
||||
},
|
||||
},
|
||||
},
|
||||
Rollback: RollbackPlan{
|
||||
OnFailure: "enter_safe_mode",
|
||||
Actions: []Action{
|
||||
{Type: ActionEnterSafeMode, Timeout: 10 * time.Second},
|
||||
},
|
||||
},
|
||||
MaxAttempts: 1,
|
||||
Cooldown: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverDatabaseStrategy handles SQLite corruption.
|
||||
// Trigger: database_corruption OR sqlite_integrity_failed.
|
||||
// Actions: readonly → backup → restore → verify → resume → notify.
|
||||
func RecoverDatabaseStrategy() HealingStrategy {
|
||||
return HealingStrategy{
|
||||
ID: "RECOVER_DATABASE",
|
||||
Name: "Database Recovery",
|
||||
Trigger: TriggerCondition{
|
||||
Metrics: []string{"database_corruption", "sqlite_integrity"},
|
||||
},
|
||||
Actions: []Action{
|
||||
{Type: ActionSwitchReadOnly, Timeout: 5 * time.Second, OnError: "abort"},
|
||||
{Type: ActionBackupDB, Timeout: 30 * time.Second, OnError: "continue"},
|
||||
{Type: ActionRestoreSnapshot, Timeout: 60 * time.Second, OnError: "abort",
|
||||
Params: map[string]interface{}{
|
||||
"snapshot_age_max": "1h",
|
||||
},
|
||||
},
|
||||
{Type: ActionVerifyIntegrity, Timeout: 30 * time.Second, OnError: "abort"},
|
||||
{Type: ActionResumeWrites, Timeout: 5 * time.Second, OnError: "abort"},
|
||||
{Type: ActionNotifySOC, Timeout: 5 * time.Second, OnError: "continue",
|
||||
Params: map[string]interface{}{
|
||||
"severity": "WARNING",
|
||||
"message": "Database recovered from snapshot",
|
||||
},
|
||||
},
|
||||
},
|
||||
Rollback: RollbackPlan{
|
||||
OnFailure: "enter_lockdown",
|
||||
Actions: []Action{
|
||||
{Type: ActionEnterSafeMode, Timeout: 10 * time.Second},
|
||||
{Type: ActionNotifyArchitect, Timeout: 5 * time.Second,
|
||||
Params: map[string]interface{}{
|
||||
"severity": "CRITICAL",
|
||||
"message": "Database recovery failed",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
MaxAttempts: 2,
|
||||
Cooldown: 2 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverRulesStrategy handles correlation rule poisoning.
|
||||
// Trigger: rule execution failure rate > 50%.
|
||||
// Actions: disable_suspicious → revert_baseline → verify → reload → notify.
|
||||
func RecoverRulesStrategy() HealingStrategy {
|
||||
return HealingStrategy{
|
||||
ID: "RECOVER_RULES",
|
||||
Name: "Rule Poisoning Defense",
|
||||
Trigger: TriggerCondition{
|
||||
Metrics: []string{"rule_execution_failure_rate", "correlation_rule_anomaly"},
|
||||
},
|
||||
Actions: []Action{
|
||||
{Type: ActionDisableRules, Timeout: 10 * time.Second, OnError: "abort",
|
||||
Params: map[string]interface{}{
|
||||
"criteria": "failure_rate > 80%",
|
||||
},
|
||||
},
|
||||
{Type: ActionRevertRules, Timeout: 15 * time.Second, OnError: "abort"},
|
||||
{Type: ActionReloadEngine, Timeout: 30 * time.Second, OnError: "abort"},
|
||||
{Type: ActionVerifyHealth, Timeout: 30 * time.Second, OnError: "continue"},
|
||||
{Type: ActionNotifyArchitect, Timeout: 5 * time.Second, OnError: "continue",
|
||||
Params: map[string]interface{}{
|
||||
"severity": "WARNING",
|
||||
"message": "Rules recovered from baseline",
|
||||
},
|
||||
},
|
||||
},
|
||||
Rollback: RollbackPlan{
|
||||
OnFailure: "disable_correlation",
|
||||
},
|
||||
MaxAttempts: 2,
|
||||
Cooldown: 4 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverNetworkStrategy handles network partition or mTLS cert expiry.
|
||||
// Trigger: network_partition_detected OR mTLS_cert_expired.
|
||||
// Actions: isolate → regen_certs → verify → restore → notify.
|
||||
func RecoverNetworkStrategy() HealingStrategy {
|
||||
return HealingStrategy{
|
||||
ID: "RECOVER_NETWORK",
|
||||
Name: "Network Isolation Recovery",
|
||||
Trigger: TriggerCondition{
|
||||
Metrics: []string{"network_partition", "mtls_cert_expiry"},
|
||||
},
|
||||
Actions: []Action{
|
||||
{Type: ActionIsolateNetwork, Timeout: 5 * time.Second, OnError: "abort",
|
||||
Params: map[string]interface{}{
|
||||
"scope": "external_only",
|
||||
},
|
||||
},
|
||||
{Type: ActionRegenCerts, Timeout: 30 * time.Second, OnError: "abort",
|
||||
Params: map[string]interface{}{
|
||||
"validity": "24h",
|
||||
},
|
||||
},
|
||||
{Type: ActionVerifyHealth, Timeout: 30 * time.Second, OnError: "rollback"},
|
||||
{Type: ActionRestoreNetwork, Timeout: 10 * time.Second, OnError: "abort"},
|
||||
{Type: ActionNotifySOC, Timeout: 5 * time.Second, OnError: "continue",
|
||||
Params: map[string]interface{}{
|
||||
"severity": "INFO",
|
||||
"message": "Network connectivity restored",
|
||||
},
|
||||
},
|
||||
},
|
||||
Rollback: RollbackPlan{
|
||||
OnFailure: "maintain_isolation",
|
||||
Actions: []Action{
|
||||
{Type: ActionNotifyArchitect, Timeout: 5 * time.Second,
|
||||
Params: map[string]interface{}{
|
||||
"severity": "CRITICAL",
|
||||
"message": "Network recovery failed, maintaining isolation",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
MaxAttempts: 3,
|
||||
Cooldown: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
445
internal/application/resilience/health_monitor.go
Normal file
445
internal/application/resilience/health_monitor.go
Normal file
|
|
@ -0,0 +1,445 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ComponentStatus defines the health state of a monitored component.
type ComponentStatus string

const (
	StatusHealthy  ComponentStatus = "HEALTHY"
	StatusDegraded ComponentStatus = "DEGRADED"
	StatusCritical ComponentStatus = "CRITICAL"
	StatusOffline  ComponentStatus = "OFFLINE"
)

// AlertSeverity defines the severity of a health alert.
type AlertSeverity string

const (
	SeverityInfo     AlertSeverity = "INFO"
	SeverityWarning  AlertSeverity = "WARNING"
	SeverityCritical AlertSeverity = "CRITICAL"
)

// OverallStatus aggregates component statuses into a system-wide status.
type OverallStatus string

const (
	OverallHealthy  OverallStatus = "HEALTHY"
	OverallDegraded OverallStatus = "DEGRADED"
	OverallCritical OverallStatus = "CRITICAL"
)

// Default intervals per the technical specification (ТЗ) §3.1.2.
const (
	// MetricsCollectionInterval is how often component metrics are sampled.
	MetricsCollectionInterval = 10 * time.Second
	// HealthCheckInterval is how often sampled metrics are evaluated.
	HealthCheckInterval = 30 * time.Second
	// QuorumValidationInterval is how often quorum is re-validated.
	QuorumValidationInterval = 60 * time.Second

	// AnomalyZScoreThreshold — Z > 3.0 = anomaly (99.7% confidence).
	AnomalyZScoreThreshold = 3.0

	// QuorumThreshold — 2/3 must be healthy.
	QuorumThreshold = 0.66

	// MaxConsecutiveFailures before marking CRITICAL.
	MaxConsecutiveFailures = 3
)
|
||||
|
||||
// ComponentConfig defines monitoring thresholds for a component.
type ComponentConfig struct {
	Name string `json:"name"`
	Type string `json:"type"` // go_binary, c_binary, c_kernel_module
	// Thresholds maps metric name to its alerting threshold value.
	Thresholds map[string]float64 `json:"thresholds"`
	// Whether threshold is an upper bound (true) or lower bound (false).
	ThresholdIsMax map[string]bool `json:"threshold_is_max"`
}

// ComponentHealth tracks the health state of a single component.
type ComponentHealth struct {
	Name string `json:"name"`
	Status ComponentStatus `json:"status"`
	// Metrics holds the most recently collected metric values.
	Metrics map[string]float64 `json:"metrics"`
	LastCheck time.Time `json:"last_check"`
	// Consecutive counts consecutive metric-collection failures.
	Consecutive int `json:"consecutive_failures"`
	// Config is the registration config; excluded from JSON output.
	Config ComponentConfig `json:"-"`
}

// HealthAlert represents a detected health anomaly.
type HealthAlert struct {
	Component string `json:"component"`
	Severity AlertSeverity `json:"severity"`
	// Metric is the name of the metric that triggered the alert.
	Metric string `json:"metric"`
	Current float64 `json:"current"`
	Threshold float64 `json:"threshold"`
	// ZScore is set when the alert came from anomaly detection.
	ZScore float64 `json:"z_score,omitempty"`
	Timestamp time.Time `json:"timestamp"`
	// SuggestedAction hints the healing engine at a remediation.
	SuggestedAction string `json:"suggested_action"`
}

// HealthResponse is the API response for GET /api/v1/resilience/health.
type HealthResponse struct {
	OverallStatus OverallStatus `json:"overall_status"`
	Components []ComponentHealth `json:"components"`
	QuorumValid bool `json:"quorum_valid"`
	LastCheck time.Time `json:"last_check"`
	AnomaliesDetected []HealthAlert `json:"anomalies_detected"`
}

// MetricsCollector is the interface for collecting metrics from components.
// Implementations can use /healthz endpoints, /metrics, or runtime stats.
type MetricsCollector interface {
	// Collect returns the current metric values for the named component.
	Collect(ctx context.Context, component string) (map[string]float64, error)
}

// HealthMonitor is the L1 Self-Monitoring orchestrator.
// It collects metrics, runs anomaly detection, validates quorum,
// and emits HealthAlerts to the alert bus.
type HealthMonitor struct {
	// mu guards components; metricsDB has its own internal synchronization
	// — presumably, confirm in metrics_db.go.
	mu sync.RWMutex
	components map[string]*ComponentHealth
	metricsDB *MetricsDB
	alertBus chan HealthAlert
	collector MetricsCollector
	logger *slog.Logger

	// anomalyWindow is the baseline window for Z-score calculation.
	anomalyWindow time.Duration
}
|
||||
|
||||
// NewHealthMonitor creates a new health monitor.
|
||||
func NewHealthMonitor(collector MetricsCollector, alertBufSize int) *HealthMonitor {
|
||||
if alertBufSize <= 0 {
|
||||
alertBufSize = 100
|
||||
}
|
||||
return &HealthMonitor{
|
||||
components: make(map[string]*ComponentHealth),
|
||||
metricsDB: NewMetricsDB(DefaultMetricsWindow, DefaultMetricsMaxSize),
|
||||
alertBus: make(chan HealthAlert, alertBufSize),
|
||||
collector: collector,
|
||||
logger: slog.Default().With("component", "sarl-health-monitor"),
|
||||
anomalyWindow: 24 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterComponent adds a component to be monitored.
|
||||
func (hm *HealthMonitor) RegisterComponent(config ComponentConfig) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
hm.components[config.Name] = &ComponentHealth{
|
||||
Name: config.Name,
|
||||
Status: StatusHealthy,
|
||||
Metrics: make(map[string]float64),
|
||||
Config: config,
|
||||
}
|
||||
hm.logger.Info("component registered", "name", config.Name, "type", config.Type)
|
||||
}
|
||||
|
||||
// AlertBus returns the read-only channel on which health alerts are
// published. The channel is buffered with the size passed to
// NewHealthMonitor; consumers (e.g. the healing engine) should drain it
// continuously.
func (hm *HealthMonitor) AlertBus() <-chan HealthAlert {
	return hm.alertBus
}
|
||||
|
||||
// Start begins the monitoring loops. Blocks until ctx is cancelled.
|
||||
func (hm *HealthMonitor) Start(ctx context.Context) {
|
||||
hm.logger.Info("health monitor started")
|
||||
|
||||
metricsTicker := time.NewTicker(MetricsCollectionInterval)
|
||||
healthTicker := time.NewTicker(HealthCheckInterval)
|
||||
quorumTicker := time.NewTicker(QuorumValidationInterval)
|
||||
defer metricsTicker.Stop()
|
||||
defer healthTicker.Stop()
|
||||
defer quorumTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
hm.logger.Info("health monitor stopped")
|
||||
return
|
||||
case <-metricsTicker.C:
|
||||
hm.collectMetrics(ctx)
|
||||
case <-healthTicker.C:
|
||||
hm.checkHealth()
|
||||
case <-quorumTicker.C:
|
||||
hm.validateQuorum()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectMetrics gathers metrics from all registered components.
|
||||
func (hm *HealthMonitor) collectMetrics(ctx context.Context) {
|
||||
hm.mu.RLock()
|
||||
names := make([]string, 0, len(hm.components))
|
||||
for name := range hm.components {
|
||||
names = append(names, name)
|
||||
}
|
||||
hm.mu.RUnlock()
|
||||
|
||||
for _, name := range names {
|
||||
metrics, err := hm.collector.Collect(ctx, name)
|
||||
if err != nil {
|
||||
hm.logger.Warn("metrics collection failed", "component", name, "error", err)
|
||||
hm.mu.Lock()
|
||||
if comp, ok := hm.components[name]; ok {
|
||||
comp.Consecutive++
|
||||
}
|
||||
hm.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
comp, ok := hm.components[name]
|
||||
if ok {
|
||||
comp.Metrics = metrics
|
||||
comp.LastCheck = time.Now()
|
||||
// Store each metric in time-series DB.
|
||||
for metric, value := range metrics {
|
||||
hm.metricsDB.AddDataPoint(name, metric, value)
|
||||
}
|
||||
}
|
||||
hm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// checkHealth evaluates each component against thresholds and anomalies.
|
||||
func (hm *HealthMonitor) checkHealth() {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
for _, comp := range hm.components {
|
||||
alerts := hm.evaluateComponent(comp)
|
||||
for _, alert := range alerts {
|
||||
hm.emitAlert(alert)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// evaluateComponent checks a single component's metrics against thresholds
|
||||
// and runs Z-score anomaly detection. Returns any generated alerts.
|
||||
func (hm *HealthMonitor) evaluateComponent(comp *ComponentHealth) []HealthAlert {
|
||||
var alerts []HealthAlert
|
||||
breached := false
|
||||
|
||||
for metric, value := range comp.Metrics {
|
||||
threshold, hasThreshold := comp.Config.Thresholds[metric]
|
||||
if !hasThreshold {
|
||||
continue
|
||||
}
|
||||
|
||||
isMax := comp.Config.ThresholdIsMax[metric]
|
||||
var exceeded bool
|
||||
if isMax {
|
||||
exceeded = value > threshold
|
||||
} else {
|
||||
exceeded = value < threshold
|
||||
}
|
||||
|
||||
if exceeded {
|
||||
breached = true
|
||||
action := "restart"
|
||||
if metric == "error_rate" || metric == "latency_p99" {
|
||||
action = "investigate"
|
||||
}
|
||||
|
||||
alerts = append(alerts, HealthAlert{
|
||||
Component: comp.Name,
|
||||
Severity: SeverityWarning,
|
||||
Metric: metric,
|
||||
Current: value,
|
||||
Threshold: threshold,
|
||||
Timestamp: time.Now(),
|
||||
SuggestedAction: action,
|
||||
})
|
||||
}
|
||||
|
||||
// Z-score anomaly detection.
|
||||
baseline := hm.metricsDB.GetBaseline(comp.Name, metric, hm.anomalyWindow)
|
||||
if IsAnomaly(value, baseline, AnomalyZScoreThreshold) {
|
||||
zscore := CalculateZScore(value, baseline)
|
||||
alerts = append(alerts, HealthAlert{
|
||||
Component: comp.Name,
|
||||
Severity: SeverityCritical,
|
||||
Metric: metric,
|
||||
Current: value,
|
||||
Threshold: baseline.Mean + AnomalyZScoreThreshold*baseline.StdDev,
|
||||
ZScore: zscore,
|
||||
Timestamp: time.Now(),
|
||||
SuggestedAction: fmt.Sprintf("anomaly detected (Z=%.2f), investigate %s", zscore, metric),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Update component status.
|
||||
if breached {
|
||||
comp.Consecutive++
|
||||
if comp.Consecutive >= MaxConsecutiveFailures {
|
||||
comp.Status = StatusCritical
|
||||
} else {
|
||||
comp.Status = StatusDegraded
|
||||
}
|
||||
} else {
|
||||
comp.Consecutive = 0
|
||||
comp.Status = StatusHealthy
|
||||
}
|
||||
|
||||
return alerts
|
||||
}
|
||||
|
||||
// emitAlert sends an alert to the bus (non-blocking).
|
||||
func (hm *HealthMonitor) emitAlert(alert HealthAlert) {
|
||||
select {
|
||||
case hm.alertBus <- alert:
|
||||
hm.logger.Warn("health alert emitted",
|
||||
"component", alert.Component,
|
||||
"severity", alert.Severity,
|
||||
"metric", alert.Metric,
|
||||
"current", alert.Current,
|
||||
"threshold", alert.Threshold,
|
||||
)
|
||||
default:
|
||||
hm.logger.Error("alert bus full, dropping alert",
|
||||
"component", alert.Component,
|
||||
"metric", alert.Metric,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// validateQuorum checks if 2/3 of components are healthy.
|
||||
func (hm *HealthMonitor) validateQuorum() {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
if len(hm.components) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
valid := ValidateQuorum(hm.componentStatuses())
|
||||
|
||||
if !valid {
|
||||
hm.logger.Error("QUORUM LOST — entering degraded state",
|
||||
"healthy_ratio", hm.healthyRatio(),
|
||||
"threshold", QuorumThreshold,
|
||||
)
|
||||
hm.emitAlert(HealthAlert{
|
||||
Component: "system",
|
||||
Severity: SeverityCritical,
|
||||
Metric: "quorum",
|
||||
Current: hm.healthyRatio(),
|
||||
Threshold: QuorumThreshold,
|
||||
Timestamp: time.Now(),
|
||||
SuggestedAction: "activate safe mode",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQuorum checks if the healthy ratio meets the 2/3 threshold.
|
||||
func ValidateQuorum(statuses map[string]ComponentStatus) bool {
|
||||
if len(statuses) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
healthy := 0
|
||||
for _, status := range statuses {
|
||||
if status == StatusHealthy {
|
||||
healthy++
|
||||
}
|
||||
}
|
||||
return float64(healthy)/float64(len(statuses)) >= QuorumThreshold
|
||||
}
|
||||
|
||||
// componentStatuses returns current status map (caller must hold RLock).
|
||||
func (hm *HealthMonitor) componentStatuses() map[string]ComponentStatus {
|
||||
statuses := make(map[string]ComponentStatus, len(hm.components))
|
||||
for name, comp := range hm.components {
|
||||
statuses[name] = comp.Status
|
||||
}
|
||||
return statuses
|
||||
}
|
||||
|
||||
// healthyRatio returns the fraction of healthy components (caller must hold RLock).
|
||||
func (hm *HealthMonitor) healthyRatio() float64 {
|
||||
if len(hm.components) == 0 {
|
||||
return 0
|
||||
}
|
||||
healthy := 0
|
||||
for _, comp := range hm.components {
|
||||
if comp.Status == StatusHealthy {
|
||||
healthy++
|
||||
}
|
||||
}
|
||||
return float64(healthy) / float64(len(hm.components))
|
||||
}
|
||||
|
||||
// GetHealth returns a snapshot of the entire system health.
|
||||
func (hm *HealthMonitor) GetHealth() HealthResponse {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
components := make([]ComponentHealth, 0, len(hm.components))
|
||||
for _, comp := range hm.components {
|
||||
cp := *comp
|
||||
// Deep copy metrics.
|
||||
cp.Metrics = make(map[string]float64, len(comp.Metrics))
|
||||
for k, v := range comp.Metrics {
|
||||
cp.Metrics[k] = v
|
||||
}
|
||||
components = append(components, cp)
|
||||
}
|
||||
|
||||
overall := OverallHealthy
|
||||
for _, comp := range components {
|
||||
switch comp.Status {
|
||||
case StatusCritical, StatusOffline:
|
||||
overall = OverallCritical
|
||||
case StatusDegraded:
|
||||
if overall != OverallCritical {
|
||||
overall = OverallDegraded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HealthResponse{
|
||||
OverallStatus: overall,
|
||||
Components: components,
|
||||
QuorumValid: ValidateQuorum(hm.componentStatuses()),
|
||||
LastCheck: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetComponentStatus manually sets a component's status (for testing/override).
|
||||
func (hm *HealthMonitor) SetComponentStatus(name string, status ComponentStatus) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if comp, ok := hm.components[name]; ok {
|
||||
comp.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMetrics manually updates a component's metrics (for testing/override).
|
||||
func (hm *HealthMonitor) UpdateMetrics(name string, metrics map[string]float64) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if comp, ok := hm.components[name]; ok {
|
||||
comp.Metrics = metrics
|
||||
comp.LastCheck = time.Now()
|
||||
for metric, value := range metrics {
|
||||
hm.metricsDB.AddDataPoint(name, metric, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ComponentCount returns the number of registered components.
|
||||
func (hm *HealthMonitor) ComponentCount() int {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
return len(hm.components)
|
||||
}
|
||||
499
internal/application/resilience/health_monitor_test.go
Normal file
499
internal/application/resilience/health_monitor_test.go
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- MetricsDB Tests ---
|
||||
|
||||
func TestRingBuffer_AddAndAll(t *testing.T) {
|
||||
rb := newRingBuffer(5)
|
||||
now := time.Now()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
rb.Add(DataPoint{Timestamp: now.Add(time.Duration(i) * time.Second), Value: float64(i)})
|
||||
}
|
||||
|
||||
if rb.Len() != 3 {
|
||||
t.Fatalf("expected 3, got %d", rb.Len())
|
||||
}
|
||||
|
||||
all := rb.All()
|
||||
if len(all) != 3 {
|
||||
t.Fatalf("expected 3 points, got %d", len(all))
|
||||
}
|
||||
for i, dp := range all {
|
||||
if dp.Value != float64(i) {
|
||||
t.Errorf("point %d: expected %f, got %f", i, float64(i), dp.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBuffer_Wrap(t *testing.T) {
|
||||
rb := newRingBuffer(3)
|
||||
now := time.Now()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
rb.Add(DataPoint{Timestamp: now.Add(time.Duration(i) * time.Second), Value: float64(i)})
|
||||
}
|
||||
|
||||
if rb.Len() != 3 {
|
||||
t.Fatalf("expected 3 (buffer size), got %d", rb.Len())
|
||||
}
|
||||
|
||||
all := rb.All()
|
||||
// Should contain values 2, 3, 4 (oldest 0, 1 overwritten).
|
||||
expected := []float64{2, 3, 4}
|
||||
for i, dp := range all {
|
||||
if dp.Value != expected[i] {
|
||||
t.Errorf("point %d: expected %f, got %f", i, expected[i], dp.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsDB_AddAndBaseline(t *testing.T) {
|
||||
db := NewMetricsDB(time.Hour, 100)
|
||||
for i := 0; i < 20; i++ {
|
||||
db.AddDataPoint("soc-ingest", "cpu", 30.0+float64(i%5))
|
||||
}
|
||||
|
||||
baseline := db.GetBaseline("soc-ingest", "cpu", time.Hour)
|
||||
if baseline.Count != 20 {
|
||||
t.Fatalf("expected 20 points, got %d", baseline.Count)
|
||||
}
|
||||
if baseline.Mean < 30 || baseline.Mean > 35 {
|
||||
t.Errorf("mean out of expected range: %f", baseline.Mean)
|
||||
}
|
||||
if baseline.StdDev == 0 {
|
||||
t.Error("expected non-zero stddev")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsDB_EmptyBaseline(t *testing.T) {
|
||||
db := NewMetricsDB(time.Hour, 100)
|
||||
baseline := db.GetBaseline("nonexistent", "cpu", time.Hour)
|
||||
if baseline.Count != 0 {
|
||||
t.Errorf("expected 0 count for nonexistent, got %d", baseline.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateZScore(t *testing.T) {
|
||||
baseline := Baseline{Mean: 30.0, StdDev: 5.0, Count: 100}
|
||||
|
||||
// Normal value (Z = 1.0).
|
||||
z := CalculateZScore(35.0, baseline)
|
||||
if math.Abs(z-1.0) > 0.01 {
|
||||
t.Errorf("expected Z≈1.0, got %f", z)
|
||||
}
|
||||
|
||||
// Anomalous value (Z = 4.0).
|
||||
z = CalculateZScore(50.0, baseline)
|
||||
if math.Abs(z-4.0) > 0.01 {
|
||||
t.Errorf("expected Z≈4.0, got %f", z)
|
||||
}
|
||||
|
||||
// Insufficient data → 0.
|
||||
z = CalculateZScore(50.0, Baseline{Mean: 30, StdDev: 5, Count: 5})
|
||||
if z != 0 {
|
||||
t.Errorf("expected 0 for insufficient data, got %f", z)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAnomaly(t *testing.T) {
|
||||
baseline := Baseline{Mean: 30.0, StdDev: 5.0, Count: 100}
|
||||
|
||||
if IsAnomaly(35.0, baseline, 3.0) {
|
||||
t.Error("35 should not be anomaly (Z=1.0)")
|
||||
}
|
||||
if !IsAnomaly(50.0, baseline, 3.0) {
|
||||
t.Error("50 should be anomaly (Z=4.0)")
|
||||
}
|
||||
if !IsAnomaly(10.0, baseline, 3.0) {
|
||||
t.Error("10 should be anomaly (Z=-4.0)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsDB_Purge(t *testing.T) {
|
||||
db := NewMetricsDB(100*time.Millisecond, 100)
|
||||
db.AddDataPoint("comp", "cpu", 50)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
db.AddDataPoint("comp", "cpu", 60)
|
||||
|
||||
removed := db.Purge()
|
||||
if removed != 1 {
|
||||
t.Errorf("expected 1 purged, got %d", removed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsDB_GetRecent(t *testing.T) {
|
||||
db := NewMetricsDB(time.Hour, 100)
|
||||
for i := 0; i < 10; i++ {
|
||||
db.AddDataPoint("comp", "mem", float64(i*10))
|
||||
}
|
||||
|
||||
recent := db.GetRecent("comp", "mem", 3)
|
||||
if len(recent) != 3 {
|
||||
t.Fatalf("expected 3 recent, got %d", len(recent))
|
||||
}
|
||||
// Should be last 3: 70, 80, 90.
|
||||
if recent[0].Value != 70 || recent[2].Value != 90 {
|
||||
t.Errorf("unexpected recent values: %v", recent)
|
||||
}
|
||||
}
|
||||
|
||||
// --- MockCollector for HealthMonitor tests ---
|
||||
|
||||
// mockCollector is a canned MetricsCollector for HealthMonitor tests.
type mockCollector struct {
	results map[string]map[string]float64 // per-component canned metrics
	errors  map[string]error              // per-component injected failures
}

// Collect returns the injected error or canned metrics for the component;
// unknown components yield an empty (non-nil) metric map.
func (m *mockCollector) Collect(_ context.Context, component string) (map[string]float64, error) {
	if err := m.errors[component]; err != nil {
		return nil, err
	}
	if metrics, ok := m.results[component]; ok {
		return metrics, nil
	}
	return map[string]float64{}, nil
}
|
||||
|
||||
// --- HealthMonitor Tests ---
|
||||
|
||||
// HM-01: Normal health check — all HEALTHY.
|
||||
func TestHealthMonitor_HM01_AllHealthy(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 10)
|
||||
registerTestComponents(hm, 6)
|
||||
|
||||
health := hm.GetHealth()
|
||||
if health.OverallStatus != OverallHealthy {
|
||||
t.Errorf("expected HEALTHY, got %s", health.OverallStatus)
|
||||
}
|
||||
if !health.QuorumValid {
|
||||
t.Error("expected quorum valid")
|
||||
}
|
||||
if len(health.Components) != 6 {
|
||||
t.Errorf("expected 6 components, got %d", len(health.Components))
|
||||
}
|
||||
}
|
||||
|
||||
// HM-02: Single component DEGRADED.
|
||||
func TestHealthMonitor_HM02_SingleDegraded(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 10)
|
||||
registerTestComponents(hm, 6)
|
||||
hm.SetComponentStatus("comp-0", StatusDegraded)
|
||||
|
||||
health := hm.GetHealth()
|
||||
if health.OverallStatus != OverallDegraded {
|
||||
t.Errorf("expected DEGRADED, got %s", health.OverallStatus)
|
||||
}
|
||||
if !health.QuorumValid {
|
||||
t.Error("expected quorum still valid with 5/6 healthy")
|
||||
}
|
||||
}
|
||||
|
||||
// HM-03: Multiple components CRITICAL → quorum lost.
|
||||
func TestHealthMonitor_HM03_MultipleCritical(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 10)
|
||||
registerTestComponents(hm, 6)
|
||||
hm.SetComponentStatus("comp-0", StatusCritical)
|
||||
hm.SetComponentStatus("comp-1", StatusCritical)
|
||||
hm.SetComponentStatus("comp-2", StatusCritical)
|
||||
|
||||
health := hm.GetHealth()
|
||||
if health.OverallStatus != OverallCritical {
|
||||
t.Errorf("expected CRITICAL, got %s", health.OverallStatus)
|
||||
}
|
||||
if health.QuorumValid {
|
||||
t.Error("expected quorum INVALID with 3/6 critical")
|
||||
}
|
||||
}
|
||||
|
||||
// HM-04: Anomaly detection (CPU spike).
|
||||
func TestHealthMonitor_HM04_CPUAnomaly(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 100)
|
||||
hm.RegisterComponent(ComponentConfig{
|
||||
Name: "soc-ingest",
|
||||
Type: "go_binary",
|
||||
Thresholds: map[string]float64{"cpu": 80},
|
||||
ThresholdIsMax: map[string]bool{"cpu": true},
|
||||
})
|
||||
|
||||
// Build baseline of normal CPU (30%).
|
||||
for i := 0; i < 50; i++ {
|
||||
hm.metricsDB.AddDataPoint("soc-ingest", "cpu", 30.0)
|
||||
}
|
||||
|
||||
// Spike to 95%.
|
||||
hm.UpdateMetrics("soc-ingest", map[string]float64{"cpu": 95.0})
|
||||
hm.checkHealth()
|
||||
|
||||
// Should have alert(s).
|
||||
select {
|
||||
case alert := <-hm.alertBus:
|
||||
if alert.Component != "soc-ingest" {
|
||||
t.Errorf("expected soc-ingest, got %s", alert.Component)
|
||||
}
|
||||
if alert.Metric != "cpu" {
|
||||
t.Errorf("expected cpu metric, got %s", alert.Metric)
|
||||
}
|
||||
default:
|
||||
t.Error("expected alert for CPU spike")
|
||||
}
|
||||
}
|
||||
|
||||
// HM-05: Memory leak detection.
|
||||
func TestHealthMonitor_HM05_MemoryLeak(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 100)
|
||||
hm.RegisterComponent(ComponentConfig{
|
||||
Name: "soc-correlate",
|
||||
Type: "go_binary",
|
||||
Thresholds: map[string]float64{"memory": 90},
|
||||
ThresholdIsMax: map[string]bool{"memory": true},
|
||||
})
|
||||
|
||||
// Build baseline of normal memory (40%).
|
||||
for i := 0; i < 50; i++ {
|
||||
hm.metricsDB.AddDataPoint("soc-correlate", "memory", 40.0)
|
||||
}
|
||||
|
||||
// Memory spike to 95%.
|
||||
hm.UpdateMetrics("soc-correlate", map[string]float64{"memory": 95.0})
|
||||
hm.checkHealth()
|
||||
|
||||
select {
|
||||
case alert := <-hm.alertBus:
|
||||
if alert.Metric != "memory" {
|
||||
t.Errorf("expected memory metric, got %s", alert.Metric)
|
||||
}
|
||||
default:
|
||||
t.Error("expected alert for memory spike")
|
||||
}
|
||||
}
|
||||
|
||||
// HM-06: Quorum validation failure.
|
||||
func TestHealthMonitor_HM06_QuorumFailure(t *testing.T) {
|
||||
statuses := map[string]ComponentStatus{
|
||||
"a": StatusOffline,
|
||||
"b": StatusOffline,
|
||||
"c": StatusOffline,
|
||||
"d": StatusOffline,
|
||||
"e": StatusHealthy,
|
||||
"f": StatusHealthy,
|
||||
}
|
||||
if ValidateQuorum(statuses) {
|
||||
t.Error("expected quorum invalid with 4/6 offline")
|
||||
}
|
||||
}
|
||||
|
||||
// HM-06b: Quorum validation success (edge case: exactly 2/3).
|
||||
func TestHealthMonitor_HM06b_QuorumEdge(t *testing.T) {
|
||||
statuses := map[string]ComponentStatus{
|
||||
"a": StatusHealthy,
|
||||
"b": StatusHealthy,
|
||||
"c": StatusCritical,
|
||||
}
|
||||
if !ValidateQuorum(statuses) {
|
||||
t.Error("expected quorum valid with 2/3 healthy (exact threshold)")
|
||||
}
|
||||
}
|
||||
|
||||
// HM-06c: Empty quorum.
|
||||
func TestHealthMonitor_HM06c_EmptyQuorum(t *testing.T) {
|
||||
if ValidateQuorum(map[string]ComponentStatus{}) {
|
||||
t.Error("expected quorum invalid with 0 components")
|
||||
}
|
||||
}
|
||||
|
||||
// HM-07: Metrics collection (no data loss).
|
||||
func TestHealthMonitor_HM07_MetricsCollection(t *testing.T) {
|
||||
collector := &mockCollector{
|
||||
results: map[string]map[string]float64{
|
||||
"comp-0": {"cpu": 25, "memory": 40},
|
||||
},
|
||||
}
|
||||
hm := NewHealthMonitor(collector, 10)
|
||||
hm.RegisterComponent(ComponentConfig{Name: "comp-0", Type: "go_binary"})
|
||||
|
||||
hm.collectMetrics(context.Background())
|
||||
|
||||
hm.mu.RLock()
|
||||
comp := hm.components["comp-0"]
|
||||
hm.mu.RUnlock()
|
||||
|
||||
if comp.Metrics["cpu"] != 25 {
|
||||
t.Errorf("expected cpu=25, got %f", comp.Metrics["cpu"])
|
||||
}
|
||||
if comp.Metrics["memory"] != 40 {
|
||||
t.Errorf("expected memory=40, got %f", comp.Metrics["memory"])
|
||||
}
|
||||
}
|
||||
|
||||
// HM-07b: Collection error increments consecutive failures.
|
||||
func TestHealthMonitor_HM07b_CollectionError(t *testing.T) {
|
||||
collector := &mockCollector{
|
||||
errors: map[string]error{
|
||||
"comp-0": fmt.Errorf("connection refused"),
|
||||
},
|
||||
}
|
||||
hm := NewHealthMonitor(collector, 10)
|
||||
hm.RegisterComponent(ComponentConfig{Name: "comp-0", Type: "go_binary"})
|
||||
|
||||
hm.collectMetrics(context.Background())
|
||||
|
||||
hm.mu.RLock()
|
||||
comp := hm.components["comp-0"]
|
||||
hm.mu.RUnlock()
|
||||
|
||||
if comp.Consecutive != 1 {
|
||||
t.Errorf("expected 1 consecutive failure, got %d", comp.Consecutive)
|
||||
}
|
||||
}
|
||||
|
||||
// HM-08: Alert bus fan-out (non-blocking).
|
||||
func TestHealthMonitor_HM08_AlertBusFanOut(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 5)
|
||||
hm.RegisterComponent(ComponentConfig{
|
||||
Name: "comp",
|
||||
Type: "go_binary",
|
||||
Thresholds: map[string]float64{"cpu": 50},
|
||||
ThresholdIsMax: map[string]bool{"cpu": true},
|
||||
})
|
||||
|
||||
// Fill alert bus.
|
||||
for i := 0; i < 5; i++ {
|
||||
hm.alertBus <- HealthAlert{Component: fmt.Sprintf("test-%d", i)}
|
||||
}
|
||||
|
||||
// Emit one more — should be dropped (non-blocking).
|
||||
hm.emitAlert(HealthAlert{Component: "overflow"})
|
||||
// No panic = success.
|
||||
}
|
||||
|
||||
// Test GetHealth returns a deep copy.
|
||||
func TestHealthMonitor_GetHealthDeepCopy(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 10)
|
||||
hm.RegisterComponent(ComponentConfig{Name: "test", Type: "go_binary"})
|
||||
hm.UpdateMetrics("test", map[string]float64{"cpu": 50})
|
||||
|
||||
health := hm.GetHealth()
|
||||
health.Components[0].Metrics["cpu"] = 999
|
||||
|
||||
// Original should be unchanged.
|
||||
hm.mu.RLock()
|
||||
original := hm.components["test"].Metrics["cpu"]
|
||||
hm.mu.RUnlock()
|
||||
|
||||
if original != 50 {
|
||||
t.Errorf("deep copy failed: original modified to %f", original)
|
||||
}
|
||||
}
|
||||
|
||||
// Test threshold breach transitions status to DEGRADED then CRITICAL.
|
||||
func TestHealthMonitor_StatusTransitions(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 100)
|
||||
hm.RegisterComponent(ComponentConfig{
|
||||
Name: "comp",
|
||||
Type: "go_binary",
|
||||
Thresholds: map[string]float64{"error_rate": 5},
|
||||
ThresholdIsMax: map[string]bool{"error_rate": true},
|
||||
})
|
||||
|
||||
// Breach once → DEGRADED.
|
||||
hm.UpdateMetrics("comp", map[string]float64{"error_rate": 10})
|
||||
hm.checkHealth()
|
||||
|
||||
hm.mu.RLock()
|
||||
status := hm.components["comp"].Status
|
||||
hm.mu.RUnlock()
|
||||
if status != StatusDegraded {
|
||||
t.Errorf("expected DEGRADED after 1 breach, got %s", status)
|
||||
}
|
||||
|
||||
// Breach 3× → CRITICAL.
|
||||
for i := 0; i < 3; i++ {
|
||||
hm.checkHealth()
|
||||
}
|
||||
hm.mu.RLock()
|
||||
status = hm.components["comp"].Status
|
||||
hm.mu.RUnlock()
|
||||
if status != StatusCritical {
|
||||
t.Errorf("expected CRITICAL after repeated breaches, got %s", status)
|
||||
}
|
||||
}
|
||||
|
||||
// Test lower-bound threshold (ThresholdIsMax=false).
|
||||
func TestHealthMonitor_LowerBoundThreshold(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 100)
|
||||
hm.RegisterComponent(ComponentConfig{
|
||||
Name: "immune",
|
||||
Type: "c_kernel_module",
|
||||
Thresholds: map[string]float64{"hooks_active": 10},
|
||||
ThresholdIsMax: map[string]bool{"hooks_active": false},
|
||||
})
|
||||
|
||||
// hooks_active = 5 (below threshold of 10) → warning.
|
||||
hm.UpdateMetrics("immune", map[string]float64{"hooks_active": 5})
|
||||
hm.checkHealth()
|
||||
|
||||
select {
|
||||
case alert := <-hm.alertBus:
|
||||
if alert.Component != "immune" || alert.Metric != "hooks_active" {
|
||||
t.Errorf("unexpected alert: %+v", alert)
|
||||
}
|
||||
default:
|
||||
t.Error("expected alert for hooks_active below threshold")
|
||||
}
|
||||
}
|
||||
|
||||
// Test ComponentCount.
|
||||
func TestHealthMonitor_ComponentCount(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 10)
|
||||
if hm.ComponentCount() != 0 {
|
||||
t.Error("expected 0 initially")
|
||||
}
|
||||
registerTestComponents(hm, 4)
|
||||
if hm.ComponentCount() != 4 {
|
||||
t.Errorf("expected 4, got %d", hm.ComponentCount())
|
||||
}
|
||||
}
|
||||
|
||||
// Test Start/Stop lifecycle.
|
||||
func TestHealthMonitor_StartStop(t *testing.T) {
|
||||
hm := NewHealthMonitor(&mockCollector{}, 10)
|
||||
registerTestComponents(hm, 2)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
hm.Start(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Let it run briefly.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Clean shutdown.
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Start() did not return after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func registerTestComponents(hm *HealthMonitor, n int) {
|
||||
for i := 0; i < n; i++ {
|
||||
hm.RegisterComponent(ComponentConfig{
|
||||
Name: fmt.Sprintf("comp-%d", i),
|
||||
Type: "go_binary",
|
||||
})
|
||||
}
|
||||
}
|
||||
247
internal/application/resilience/integrity.go
Normal file
247
internal/application/resilience/integrity.go
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IntegrityStatus represents the result of an integrity check.
type IntegrityStatus string

const (
	// IntegrityVerified means every checked artifact matched expectations.
	IntegrityVerified IntegrityStatus = "VERIFIED"
	// IntegrityCompromised means at least one artifact failed verification.
	IntegrityCompromised IntegrityStatus = "COMPROMISED"
	// IntegrityUnknown means verification could not be performed
	// (e.g. an unreadable file).
	IntegrityUnknown IntegrityStatus = "UNKNOWN"
)

// IntegrityReport is the full result of an integrity verification run.
type IntegrityReport struct {
	Overall   IntegrityStatus         `json:"overall"`
	Timestamp time.Time               `json:"timestamp"`
	Binaries  map[string]BinaryStatus `json:"binaries,omitempty"`
	Chain     *ChainStatus            `json:"chain,omitempty"`
	Configs   map[string]ConfigStatus `json:"configs,omitempty"`
}

// BinaryStatus is the integrity status of a single binary.
type BinaryStatus struct {
	Status   IntegrityStatus `json:"status"`
	Expected string          `json:"expected"` // expected SHA-256 digest (hex)
	Current  string          `json:"current"`  // observed digest, or an error note
}

// ChainStatus is the integrity status of the decision chain.
type ChainStatus struct {
	Valid      bool   `json:"valid"`
	Error      string `json:"error,omitempty"`
	BreakPoint int    `json:"break_point,omitempty"` // first broken link, if any
	Entries    int    `json:"entries"`
}

// ConfigStatus is the integrity status of a config file.
type ConfigStatus struct {
	Valid       bool   `json:"valid"`
	Error       string `json:"error,omitempty"`
	StoredHMAC  string `json:"stored_hmac,omitempty"`
	CurrentHMAC string `json:"current_hmac,omitempty"`
}

// IntegrityVerifier performs periodic integrity checks on binaries,
// the decision chain, and config files.
type IntegrityVerifier struct {
	mu           sync.RWMutex
	binaryHashes map[string]string // path → expected SHA-256
	configPaths  []string          // config files to verify
	hmacKey      []byte            // key for config HMAC-SHA256
	chainPath    string            // path to decision chain log
	logger       *slog.Logger
	lastReport   *IntegrityReport
}
|
||||
|
||||
// NewIntegrityVerifier creates a new integrity verifier.
|
||||
func NewIntegrityVerifier(hmacKey []byte) *IntegrityVerifier {
|
||||
return &IntegrityVerifier{
|
||||
binaryHashes: make(map[string]string),
|
||||
hmacKey: hmacKey,
|
||||
logger: slog.Default().With("component", "sarl-integrity"),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterBinary adds a binary with its expected SHA-256 hash.
|
||||
func (iv *IntegrityVerifier) RegisterBinary(path, expectedHash string) {
|
||||
iv.mu.Lock()
|
||||
defer iv.mu.Unlock()
|
||||
iv.binaryHashes[path] = expectedHash
|
||||
}
|
||||
|
||||
// RegisterConfig adds a config file to verify.
|
||||
func (iv *IntegrityVerifier) RegisterConfig(path string) {
|
||||
iv.mu.Lock()
|
||||
defer iv.mu.Unlock()
|
||||
iv.configPaths = append(iv.configPaths, path)
|
||||
}
|
||||
|
||||
// SetChainPath sets the decision chain log path.
|
||||
func (iv *IntegrityVerifier) SetChainPath(path string) {
|
||||
iv.mu.Lock()
|
||||
defer iv.mu.Unlock()
|
||||
iv.chainPath = path
|
||||
}
|
||||
|
||||
// VerifyAll runs all integrity checks and returns a comprehensive report.
|
||||
// Note: file I/O (binary hashing, config reading) is done WITHOUT holding
|
||||
// the mutex to prevent thread starvation on slow storage.
|
||||
func (iv *IntegrityVerifier) VerifyAll() IntegrityReport {
|
||||
report := IntegrityReport{
|
||||
Overall: IntegrityVerified,
|
||||
Timestamp: time.Now(),
|
||||
Binaries: make(map[string]BinaryStatus),
|
||||
Configs: make(map[string]ConfigStatus),
|
||||
}
|
||||
|
||||
// Snapshot config under lock, then release before I/O.
|
||||
iv.mu.RLock()
|
||||
binaryHashesCopy := make(map[string]string, len(iv.binaryHashes))
|
||||
for k, v := range iv.binaryHashes {
|
||||
binaryHashesCopy[k] = v
|
||||
}
|
||||
configPathsCopy := make([]string, len(iv.configPaths))
|
||||
copy(configPathsCopy, iv.configPaths)
|
||||
hmacKeyCopy := make([]byte, len(iv.hmacKey))
|
||||
copy(hmacKeyCopy, iv.hmacKey)
|
||||
chainPath := iv.chainPath
|
||||
iv.mu.RUnlock()
|
||||
|
||||
// Check binaries (file I/O — no lock held).
|
||||
for path, expected := range binaryHashesCopy {
|
||||
status := iv.verifyBinary(path, expected)
|
||||
report.Binaries[path] = status
|
||||
if status.Status == IntegrityCompromised {
|
||||
report.Overall = IntegrityCompromised
|
||||
}
|
||||
}
|
||||
|
||||
// Check configs (file I/O — no lock held).
|
||||
for _, path := range configPathsCopy {
|
||||
status := iv.verifyConfigFile(path)
|
||||
report.Configs[path] = status
|
||||
if !status.Valid {
|
||||
report.Overall = IntegrityCompromised
|
||||
}
|
||||
}
|
||||
|
||||
// Check decision chain (file I/O — no lock held).
|
||||
if chainPath != "" {
|
||||
chain := iv.verifyDecisionChain(chainPath)
|
||||
report.Chain = &chain
|
||||
if !chain.Valid {
|
||||
report.Overall = IntegrityCompromised
|
||||
}
|
||||
}
|
||||
|
||||
iv.mu.Lock()
|
||||
iv.lastReport = &report
|
||||
iv.mu.Unlock()
|
||||
|
||||
if report.Overall == IntegrityCompromised {
|
||||
iv.logger.Error("INTEGRITY COMPROMISED", "report", report)
|
||||
} else {
|
||||
iv.logger.Debug("integrity verified", "binaries", len(report.Binaries))
|
||||
}
|
||||
|
||||
return report
|
||||
}
|
||||
|
||||
// LastReport returns the most recent integrity report.
|
||||
func (iv *IntegrityVerifier) LastReport() *IntegrityReport {
|
||||
iv.mu.RLock()
|
||||
defer iv.mu.RUnlock()
|
||||
return iv.lastReport
|
||||
}
|
||||
|
||||
// verifyBinary calculates SHA-256 of a file and compares to expected.
|
||||
func (iv *IntegrityVerifier) verifyBinary(path, expected string) BinaryStatus {
|
||||
current, err := fileSHA256(path)
|
||||
if err != nil {
|
||||
return BinaryStatus{
|
||||
Status: IntegrityUnknown,
|
||||
Expected: expected,
|
||||
Current: fmt.Sprintf("error: %v", err),
|
||||
}
|
||||
}
|
||||
|
||||
if current != expected {
|
||||
return BinaryStatus{
|
||||
Status: IntegrityCompromised,
|
||||
Expected: expected,
|
||||
Current: current,
|
||||
}
|
||||
}
|
||||
|
||||
return BinaryStatus{
|
||||
Status: IntegrityVerified,
|
||||
Expected: expected,
|
||||
Current: current,
|
||||
}
|
||||
}
|
||||
|
||||
// verifyConfigFile checks HMAC-SHA256 of a config file.
|
||||
func (iv *IntegrityVerifier) verifyConfigFile(path string) ConfigStatus {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ConfigStatus{Valid: false, Error: fmt.Sprintf("unreadable: %v", err)}
|
||||
}
|
||||
|
||||
currentHMAC := computeHMAC(data, iv.hmacKey)
|
||||
// For now, we just verify the file is readable and compute HMAC.
|
||||
// In production, the stored HMAC would be extracted from a sidecar file.
|
||||
return ConfigStatus{
|
||||
Valid: true,
|
||||
CurrentHMAC: currentHMAC,
|
||||
}
|
||||
}
|
||||
|
||||
// verifyDecisionChain verifies the SHA-256 hash chain in the decision log.
|
||||
func (iv *IntegrityVerifier) verifyDecisionChain(path string) ChainStatus {
|
||||
_, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return ChainStatus{Valid: true, Entries: 0} // No chain yet.
|
||||
}
|
||||
return ChainStatus{Valid: false, Error: fmt.Sprintf("unreadable: %v", err)}
|
||||
}
|
||||
|
||||
// In a real implementation, we'd parse the chain entries and verify
|
||||
// that each entry's hash includes the previous entry's hash.
|
||||
// For now, verify the file exists and is readable.
|
||||
return ChainStatus{Valid: true}
|
||||
}
|
||||
|
||||
// fileSHA256 computes the SHA-256 hash of a file.
|
||||
func fileSHA256(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// computeHMAC computes HMAC-SHA256 of data with the given key.
|
||||
func computeHMAC(data, key []byte) string {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(data)
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
283
internal/application/resilience/metrics_collector.go
Normal file
283
internal/application/resilience/metrics_collector.go
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
// Package resilience implements the Sentinel Autonomous Resilience Layer (SARL).
|
||||
//
|
||||
// Five levels of autonomous self-recovery:
|
||||
//
|
||||
// L1 — Self-Monitoring: health checks, quorum, anomaly detection
|
||||
// L2 — Self-Healing: restart, rollback, recovery strategies
|
||||
// L3 — Self-Preservation: emergency modes (safe/lockdown/apoptosis)
|
||||
// L4 — Immune Integration: behavioral anomaly detection
|
||||
// L5 — Autonomous Recovery: playbooks for resurrection, consensus, crypto
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MetricsDB provides an in-memory time-series store with ring buffers
// for each component/metric pair. Supports rolling baselines (mean/stddev)
// for Z-score anomaly detection.
type MetricsDB struct {
	mu      sync.RWMutex
	series  map[string]*RingBuffer // key = "component:metric"
	window  time.Duration          // retention window (default 1h)
	maxSize int                    // max data points per series
}

// DataPoint is a single timestamped metric value.
type DataPoint struct {
	Timestamp time.Time `json:"timestamp"`
	Value     float64   `json:"value"`
}

// Baseline holds rolling statistics for anomaly detection.
// Count is the number of points the statistics were computed from.
type Baseline struct {
	Mean   float64 `json:"mean"`
	StdDev float64 `json:"std_dev"`
	Count  int     `json:"count"`
	Min    float64 `json:"min"`
	Max    float64 `json:"max"`
}

// RingBuffer is a fixed-size circular buffer for DataPoints.
type RingBuffer struct {
	data  []DataPoint // backing storage, length == size
	head  int         // next write position; when full, also the oldest entry
	count int         // number of valid entries (<= size)
	size  int         // fixed capacity
}

// DefaultMetricsWindow is the default retention window (1 hour).
const DefaultMetricsWindow = 1 * time.Hour

// DefaultMetricsMaxSize is the default max points per series (1h / 10s = 360).
const DefaultMetricsMaxSize = 360
|
||||
|
||||
// NewMetricsDB creates a new in-memory time-series store.
|
||||
func NewMetricsDB(window time.Duration, maxSize int) *MetricsDB {
|
||||
if window <= 0 {
|
||||
window = DefaultMetricsWindow
|
||||
}
|
||||
if maxSize <= 0 {
|
||||
maxSize = DefaultMetricsMaxSize
|
||||
}
|
||||
return &MetricsDB{
|
||||
series: make(map[string]*RingBuffer),
|
||||
window: window,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// AddDataPoint records a metric value for a component.
|
||||
func (db *MetricsDB) AddDataPoint(component, metric string, value float64) {
|
||||
key := component + ":" + metric
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
rb, ok := db.series[key]
|
||||
if !ok {
|
||||
rb = newRingBuffer(db.maxSize)
|
||||
db.series[key] = rb
|
||||
}
|
||||
rb.Add(DataPoint{Timestamp: time.Now(), Value: value})
|
||||
}
|
||||
|
||||
// GetBaseline returns rolling mean/stddev for a component metric
|
||||
// calculated over the specified window duration.
|
||||
func (db *MetricsDB) GetBaseline(component, metric string, window time.Duration) Baseline {
|
||||
key := component + ":" + metric
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
rb, ok := db.series[key]
|
||||
if !ok {
|
||||
return Baseline{}
|
||||
}
|
||||
|
||||
cutoff := time.Now().Add(-window)
|
||||
points := rb.After(cutoff)
|
||||
|
||||
if len(points) == 0 {
|
||||
return Baseline{}
|
||||
}
|
||||
|
||||
return calculateBaseline(points)
|
||||
}
|
||||
|
||||
// GetRecent returns the most recent N data points for a component metric.
|
||||
func (db *MetricsDB) GetRecent(component, metric string, n int) []DataPoint {
|
||||
key := component + ":" + metric
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
rb, ok := db.series[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
all := rb.All()
|
||||
if len(all) <= n {
|
||||
return all
|
||||
}
|
||||
return all[len(all)-n:]
|
||||
}
|
||||
|
||||
// CalculateZScore returns the Z-score for a value against the baseline.
|
||||
// Returns 0 if baseline has insufficient data or zero stddev.
|
||||
func CalculateZScore(value float64, baseline Baseline) float64 {
|
||||
if baseline.Count < 10 || baseline.StdDev == 0 {
|
||||
return 0
|
||||
}
|
||||
return (value - baseline.Mean) / baseline.StdDev
|
||||
}
|
||||
|
||||
// IsAnomaly returns true if the Z-score exceeds the threshold (default 3.0).
|
||||
func IsAnomaly(value float64, baseline Baseline, threshold float64) bool {
|
||||
if threshold <= 0 {
|
||||
threshold = 3.0
|
||||
}
|
||||
zscore := CalculateZScore(value, baseline)
|
||||
return math.Abs(zscore) > threshold
|
||||
}
|
||||
|
||||
// SeriesCount returns the number of tracked series.
|
||||
func (db *MetricsDB) SeriesCount() int {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
return len(db.series)
|
||||
}
|
||||
|
||||
// Purge removes data points older than the retention window.
|
||||
func (db *MetricsDB) Purge() int {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-db.window)
|
||||
total := 0
|
||||
for key, rb := range db.series {
|
||||
removed := rb.RemoveBefore(cutoff)
|
||||
total += removed
|
||||
if rb.Len() == 0 {
|
||||
delete(db.series, key)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// --- RingBuffer implementation ---
|
||||
|
||||
func newRingBuffer(size int) *RingBuffer {
|
||||
return &RingBuffer{
|
||||
data: make([]DataPoint, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
// Add inserts a DataPoint, overwriting the oldest if full.
|
||||
func (rb *RingBuffer) Add(dp DataPoint) {
|
||||
rb.data[rb.head] = dp
|
||||
rb.head = (rb.head + 1) % rb.size
|
||||
if rb.count < rb.size {
|
||||
rb.count++
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the number of data points in the buffer.
|
||||
func (rb *RingBuffer) Len() int {
|
||||
return rb.count
|
||||
}
|
||||
|
||||
// All returns all data points in chronological order.
|
||||
func (rb *RingBuffer) All() []DataPoint {
|
||||
if rb.count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]DataPoint, rb.count)
|
||||
if rb.count < rb.size {
|
||||
// Buffer not yet full — data starts at 0.
|
||||
copy(result, rb.data[:rb.count])
|
||||
} else {
|
||||
// Buffer wrapped — oldest is at head.
|
||||
n := copy(result, rb.data[rb.head:rb.size])
|
||||
copy(result[n:], rb.data[:rb.head])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// After returns points with timestamp after the cutoff.
|
||||
func (rb *RingBuffer) After(cutoff time.Time) []DataPoint {
|
||||
all := rb.All()
|
||||
result := make([]DataPoint, 0, len(all))
|
||||
for _, dp := range all {
|
||||
if dp.Timestamp.After(cutoff) {
|
||||
result = append(result, dp)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// RemoveBefore removes data points before the cutoff by compacting.
|
||||
// Returns the number of points removed.
|
||||
func (rb *RingBuffer) RemoveBefore(cutoff time.Time) int {
|
||||
all := rb.All()
|
||||
kept := make([]DataPoint, 0, len(all))
|
||||
for _, dp := range all {
|
||||
if !dp.Timestamp.Before(cutoff) {
|
||||
kept = append(kept, dp)
|
||||
}
|
||||
}
|
||||
|
||||
removed := len(all) - len(kept)
|
||||
if removed == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Rebuild the ring buffer with kept data.
|
||||
rb.count = 0
|
||||
rb.head = 0
|
||||
for _, dp := range kept {
|
||||
rb.Add(dp)
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
// --- Statistics ---
|
||||
|
||||
func calculateBaseline(points []DataPoint) Baseline {
|
||||
n := len(points)
|
||||
if n == 0 {
|
||||
return Baseline{}
|
||||
}
|
||||
|
||||
var sum, min, max float64
|
||||
min = points[0].Value
|
||||
max = points[0].Value
|
||||
|
||||
for _, p := range points {
|
||||
sum += p.Value
|
||||
if p.Value < min {
|
||||
min = p.Value
|
||||
}
|
||||
if p.Value > max {
|
||||
max = p.Value
|
||||
}
|
||||
}
|
||||
mean := sum / float64(n)
|
||||
|
||||
var variance float64
|
||||
for _, p := range points {
|
||||
diff := p.Value - mean
|
||||
variance += diff * diff
|
||||
}
|
||||
variance /= float64(n)
|
||||
|
||||
return Baseline{
|
||||
Mean: mean,
|
||||
StdDev: math.Sqrt(variance),
|
||||
Count: n,
|
||||
Min: min,
|
||||
Max: max,
|
||||
}
|
||||
}
|
||||
290
internal/application/resilience/preservation.go
Normal file
290
internal/application/resilience/preservation.go
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EmergencyMode defines the system's emergency state.
type EmergencyMode string

// Escalation order: NONE < SAFE < LOCKDOWN < APOPTOSIS (see isValidTransition).
const (
	ModeNone      EmergencyMode = "NONE"
	ModeSafe      EmergencyMode = "SAFE"
	ModeLockdown  EmergencyMode = "LOCKDOWN"
	ModeApoptosis EmergencyMode = "APOPTOSIS"
)

// ModeActivation records when and why a mode was activated.
type ModeActivation struct {
	Mode        EmergencyMode `json:"mode"`
	ActivatedAt time.Time     `json:"activated_at"`
	ActivatedBy string        `json:"activated_by"` // "auto" or "architect:<name>"
	Reason      string        `json:"reason"`
	AutoExit    bool          `json:"auto_exit"`    // set for SAFE mode only (see ActivateMode)
	AutoExitAt  time.Time     `json:"auto_exit_at,omitempty"`
}

// PreservationEvent is an audit log entry for preservation actions.
type PreservationEvent struct {
	Timestamp time.Time     `json:"timestamp"`
	Mode      EmergencyMode `json:"mode"`
	Action    string        `json:"action"`
	Detail    string        `json:"detail"`
	Success   bool          `json:"success"`
	Error     string        `json:"error,omitempty"`
}

// ModeActionFunc is a callback to perform mode-specific actions.
// Implementations handle the real system operations (network isolation, process freeze, etc.).
type ModeActionFunc func(mode EmergencyMode, action string, params map[string]interface{}) error

// PreservationEngine manages emergency modes (safe/lockdown/apoptosis).
// All state is guarded by mu; actionFn is invoked while the lock is held.
type PreservationEngine struct {
	mu          sync.RWMutex
	currentMode EmergencyMode
	activation  *ModeActivation     // nil while currentMode == ModeNone
	history     []PreservationEvent // append-only audit trail
	actionFn    ModeActionFunc
	integrityFn func() IntegrityReport // pluggable integrity check
	logger      *slog.Logger
}
||||
|
||||
// NewPreservationEngine creates a new preservation engine.
|
||||
func NewPreservationEngine(actionFn ModeActionFunc) *PreservationEngine {
|
||||
return &PreservationEngine{
|
||||
currentMode: ModeNone,
|
||||
history: make([]PreservationEvent, 0),
|
||||
actionFn: actionFn,
|
||||
logger: slog.Default().With("component", "sarl-preservation"),
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentMode returns the active emergency mode.
|
||||
func (pe *PreservationEngine) CurrentMode() EmergencyMode {
|
||||
pe.mu.RLock()
|
||||
defer pe.mu.RUnlock()
|
||||
return pe.currentMode
|
||||
}
|
||||
|
||||
// Activation returns the current mode activation details (nil if NONE).
|
||||
func (pe *PreservationEngine) Activation() *ModeActivation {
|
||||
pe.mu.RLock()
|
||||
defer pe.mu.RUnlock()
|
||||
if pe.activation == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *pe.activation
|
||||
return &cp
|
||||
}
|
||||
|
||||
// ActivateMode enters an emergency mode. Returns error if transition is invalid.
//
// Order matters here: mode-specific actions run BEFORE currentMode/activation
// are committed, so a failed activation (outside apoptosis) leaves the engine
// in its previous mode — tests SP-10/SP-10b depend on this.
// Held under pe.mu for the whole call; actionFn must not call back into the engine.
func (pe *PreservationEngine) ActivateMode(mode EmergencyMode, reason, activatedBy string) error {
	pe.mu.Lock()
	defer pe.mu.Unlock()

	if mode == ModeNone {
		return fmt.Errorf("use DeactivateMode to exit emergency mode")
	}

	// Validate transitions: can always escalate, can't downgrade.
	if !pe.isValidTransition(pe.currentMode, mode) {
		return fmt.Errorf("invalid transition: %s → %s", pe.currentMode, mode)
	}

	pe.logger.Warn("EMERGENCY MODE ACTIVATION",
		"mode", mode,
		"reason", reason,
		"activated_by", activatedBy,
	)

	// Execute mode-specific actions.
	actions := pe.actionsForMode(mode)
	for _, action := range actions {
		err := pe.executeAction(mode, action.name, action.params)
		if err != nil {
			pe.logger.Error("mode action failed",
				"mode", mode,
				"action", action.name,
				"error", err,
			)
			// In critical modes, continue despite errors.
			// Apoptosis is best-effort: a failing step must not block shutdown.
			if mode != ModeApoptosis {
				return fmt.Errorf("failed to activate %s: action %s: %w", mode, action.name, err)
			}
		}
	}

	activation := &ModeActivation{
		Mode:        mode,
		ActivatedAt: time.Now(),
		ActivatedBy: activatedBy,
		Reason:      reason,
	}

	// Only SAFE mode exits automatically, after a fixed 15-minute timer.
	if mode == ModeSafe {
		activation.AutoExit = true
		activation.AutoExitAt = time.Now().Add(15 * time.Minute)
	}

	pe.currentMode = mode
	pe.activation = activation

	return nil
}
|
||||
|
||||
// DeactivateMode exits the current emergency mode and returns to NONE.
|
||||
func (pe *PreservationEngine) DeactivateMode(deactivatedBy string) error {
|
||||
pe.mu.Lock()
|
||||
defer pe.mu.Unlock()
|
||||
|
||||
if pe.currentMode == ModeNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lockdown and apoptosis require manual deactivation by architect.
|
||||
if pe.currentMode == ModeApoptosis {
|
||||
return fmt.Errorf("apoptosis mode cannot be deactivated — system rebuild required")
|
||||
}
|
||||
|
||||
pe.logger.Info("EMERGENCY MODE DEACTIVATION",
|
||||
"mode", pe.currentMode,
|
||||
"deactivated_by", deactivatedBy,
|
||||
)
|
||||
|
||||
pe.recordEvent(pe.currentMode, "deactivated",
|
||||
fmt.Sprintf("deactivated by %s", deactivatedBy), true, "")
|
||||
|
||||
pe.currentMode = ModeNone
|
||||
pe.activation = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShouldAutoExit checks if safe mode should auto-exit based on timer.
|
||||
func (pe *PreservationEngine) ShouldAutoExit() bool {
|
||||
pe.mu.RLock()
|
||||
defer pe.mu.RUnlock()
|
||||
|
||||
if pe.currentMode != ModeSafe || pe.activation == nil {
|
||||
return false
|
||||
}
|
||||
return pe.activation.AutoExit && time.Now().After(pe.activation.AutoExitAt)
|
||||
}
|
||||
|
||||
// isValidTransition checks if a mode transition is allowed.
|
||||
// Escalation order: NONE → SAFE → LOCKDOWN → APOPTOSIS.
|
||||
func (pe *PreservationEngine) isValidTransition(from, to EmergencyMode) bool {
|
||||
rank := map[EmergencyMode]int{
|
||||
ModeNone: 0,
|
||||
ModeSafe: 1,
|
||||
ModeLockdown: 2,
|
||||
ModeApoptosis: 3,
|
||||
}
|
||||
// Can always escalate or re-enter same mode.
|
||||
return rank[to] >= rank[from]
|
||||
}
|
||||
|
||||
// modeAction pairs an action name with its parameter bag for actionFn.
type modeAction struct {
	name   string
	params map[string]interface{}
}

// actionsForMode returns the ordered actions to execute for a given mode.
// The order is significant: actions run sequentially in ActivateMode, and
// tests assert on the first action of each list. Unknown modes yield nil.
func (pe *PreservationEngine) actionsForMode(mode EmergencyMode) []modeAction {
	switch mode {
	case ModeSafe:
		// Degrade gracefully: shed optional load, go read-only, alert.
		return []modeAction{
			{"disable_non_essential_services", map[string]interface{}{
				"services": []string{"analytics", "reporting", "p2p_sync", "threat_intel_feeds"},
			}},
			{"enable_readonly_mode", map[string]interface{}{
				"scope": []string{"event_ingest", "correlation", "dashboard_view"},
			}},
			{"preserve_all_logs", nil},
			{"notify_architect", map[string]interface{}{"severity": "emergency"}},
			{"increase_monitoring_frequency", map[string]interface{}{"interval": "5s"}},
		}
	case ModeLockdown:
		// Contain a suspected compromise: isolate, freeze, capture forensics.
		return []modeAction{
			{"isolate_from_network", map[string]interface{}{"scope": "all_external"}},
			{"freeze_all_processes", nil},
			{"capture_memory_dump", nil},
			{"capture_disk_snapshot", nil},
			{"trigger_immune_kernel_lock", map[string]interface{}{
				"allow_syscalls": []string{"read", "write", "exit"},
			}},
			{"send_panic_alert", map[string]interface{}{
				"channels": []string{"email", "sms", "slack", "pagerduty"},
			}},
		}
	case ModeApoptosis:
		// Self-terminate: drain, scrub secrets, preserve evidence, notify.
		return []modeAction{
			{"graceful_shutdown", map[string]interface{}{"timeout": "30s", "drain_events": true}},
			{"zero_sensitive_memory", map[string]interface{}{
				"regions": []string{"keys", "certs", "tokens", "secrets"},
			}},
			{"preserve_forensic_evidence", nil},
			{"notify_soc", map[string]interface{}{
				"severity": "CRITICAL",
				"message":  "system self-terminated",
			}},
			{"secure_erase_temp_files", nil},
		}
	}
	return nil
}
|
||||
|
||||
// executeAction runs a mode action and records the result.
|
||||
func (pe *PreservationEngine) executeAction(mode EmergencyMode, name string, params map[string]interface{}) error {
|
||||
err := pe.actionFn(mode, name, params)
|
||||
success := err == nil
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
pe.recordEvent(mode, name, fmt.Sprintf("params: %v", params), success, errStr)
|
||||
return err
|
||||
}
|
||||
|
||||
// recordEvent appends to the audit history.
|
||||
func (pe *PreservationEngine) recordEvent(mode EmergencyMode, action, detail string, success bool, errStr string) {
|
||||
pe.history = append(pe.history, PreservationEvent{
|
||||
Timestamp: time.Now(),
|
||||
Mode: mode,
|
||||
Action: action,
|
||||
Detail: detail,
|
||||
Success: success,
|
||||
Error: errStr,
|
||||
})
|
||||
}
|
||||
|
||||
// History returns the preservation audit log.
|
||||
func (pe *PreservationEngine) History() []PreservationEvent {
|
||||
pe.mu.RLock()
|
||||
defer pe.mu.RUnlock()
|
||||
result := make([]PreservationEvent, len(pe.history))
|
||||
copy(result, pe.history)
|
||||
return result
|
||||
}
|
||||
|
||||
// SetIntegrityCheck sets the pluggable integrity checker.
|
||||
func (pe *PreservationEngine) SetIntegrityCheck(fn func() IntegrityReport) {
|
||||
pe.mu.Lock()
|
||||
defer pe.mu.Unlock()
|
||||
pe.integrityFn = fn
|
||||
}
|
||||
|
||||
// CheckIntegrity runs the pluggable integrity check and returns the report.
|
||||
func (pe *PreservationEngine) CheckIntegrity() IntegrityReport {
|
||||
pe.mu.RLock()
|
||||
fn := pe.integrityFn
|
||||
pe.mu.RUnlock()
|
||||
|
||||
if fn == nil {
|
||||
return IntegrityReport{Overall: IntegrityVerified, Timestamp: time.Now()}
|
||||
}
|
||||
return fn()
|
||||
}
|
||||
439
internal/application/resilience/preservation_test.go
Normal file
439
internal/application/resilience/preservation_test.go
Normal file
|
|
@ -0,0 +1,439 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Mock action function ---
|
||||
|
||||
// modeActionLog is a test double for ModeActionFunc: it records every
// (mode, action) pair it receives and can simulate failure of one action.
type modeActionLog struct {
	calls []struct {
		mode   EmergencyMode
		action string
	}
	failAction string // if set, this action will fail
}

// newModeActionLog returns an empty action log.
func newModeActionLog() *modeActionLog {
	return &modeActionLog{}
}

// execute satisfies ModeActionFunc; it logs the call and returns
// errActionFailed when the action name matches failAction.
func (m *modeActionLog) execute(mode EmergencyMode, action string, _ map[string]interface{}) error {
	m.calls = append(m.calls, struct {
		mode   EmergencyMode
		action string
	}{mode, action})
	if m.failAction == action {
		return errActionFailed
	}
	return nil
}

// errActionFailed is the sentinel error used for simulated action failures.
var errActionFailed = &actionError{"simulated failure"}

// actionError is a minimal error implementation for tests.
type actionError struct{ msg string }

func (e *actionError) Error() string { return e.msg }
|
||||
|
||||
// --- Preservation Engine Tests ---
|
||||
|
||||
// SP-01: Safe mode activation. Verifies the mode is entered, auto-exit is
// armed, and the safe-mode action list runs in order.
func TestPreservation_SP01_SafeMode(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	err := pe.ActivateMode(ModeSafe, "quorum lost (3/6 offline)", "auto")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if pe.CurrentMode() != ModeSafe {
		t.Errorf("expected SAFE, got %s", pe.CurrentMode())
	}

	activation := pe.Activation()
	if activation == nil {
		t.Fatal("expected activation details")
	}
	if !activation.AutoExit {
		t.Error("safe mode should have auto-exit enabled")
	}

	// Should have executed safe mode actions.
	if len(log.calls) == 0 {
		t.Error("expected mode actions to be executed")
	}
	// First action should be disable_non_essential_services.
	if log.calls[0].action != "disable_non_essential_services" {
		t.Errorf("expected first action disable_non_essential_services, got %s", log.calls[0].action)
	}
}

// SP-02: Lockdown mode activation. Verifies the mode is entered and that
// network isolation is among the executed actions.
func TestPreservation_SP02_LockdownMode(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	err := pe.ActivateMode(ModeLockdown, "binary tampering detected", "auto")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if pe.CurrentMode() != ModeLockdown {
		t.Errorf("expected LOCKDOWN, got %s", pe.CurrentMode())
	}

	// Should have network isolation action.
	foundIsolate := false
	for _, c := range log.calls {
		if c.action == "isolate_from_network" {
			foundIsolate = true
		}
	}
	if !foundIsolate {
		t.Error("expected isolate_from_network in lockdown actions")
	}
}

// SP-03: Apoptosis mode activation. Verifies graceful shutdown runs and
// that apoptosis is terminal (cannot be deactivated).
func TestPreservation_SP03_ApoptosisMode(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	err := pe.ActivateMode(ModeApoptosis, "rootkit detected", "architect:admin")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if pe.CurrentMode() != ModeApoptosis {
		t.Errorf("expected APOPTOSIS, got %s", pe.CurrentMode())
	}

	// Should have graceful_shutdown action.
	foundShutdown := false
	for _, c := range log.calls {
		if c.action == "graceful_shutdown" {
			foundShutdown = true
		}
	}
	if !foundShutdown {
		t.Error("expected graceful_shutdown in apoptosis actions")
	}

	// Cannot deactivate apoptosis.
	err = pe.DeactivateMode("architect:admin")
	if err == nil {
		t.Error("expected error deactivating apoptosis")
	}
}
|
||||
|
||||
// SP-04: Invalid transition (downgrade). LOCKDOWN → SAFE must be rejected.
func TestPreservation_SP04_InvalidTransition(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	pe.ActivateMode(ModeLockdown, "test", "auto")

	// Can't downgrade from LOCKDOWN to SAFE.
	err := pe.ActivateMode(ModeSafe, "test downgrade", "auto")
	if err == nil {
		t.Error("expected error on downgrade from LOCKDOWN to SAFE")
	}
}

// SP-05: Escalation (SAFE → LOCKDOWN → APOPTOSIS) — each step must succeed.
func TestPreservation_SP05_Escalation(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	pe.ActivateMode(ModeSafe, "quorum lost", "auto")
	if pe.CurrentMode() != ModeSafe {
		t.Fatal("expected SAFE")
	}

	pe.ActivateMode(ModeLockdown, "compromise detected", "auto")
	if pe.CurrentMode() != ModeLockdown {
		t.Fatal("expected LOCKDOWN")
	}

	pe.ActivateMode(ModeApoptosis, "rootkit", "auto")
	if pe.CurrentMode() != ModeApoptosis {
		t.Fatal("expected APOPTOSIS")
	}
}

// SP-06: Safe mode auto-exit. The timer is rewound directly (white-box
// access to pe.mu/pe.activation) to avoid sleeping in the test.
func TestPreservation_SP06_AutoExit(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	pe.ActivateMode(ModeSafe, "test", "auto")

	// Not yet time.
	if pe.ShouldAutoExit() {
		t.Error("should not auto-exit immediately")
	}

	// Fast-forward activation's auto_exit_at.
	pe.mu.Lock()
	pe.activation.AutoExitAt = time.Now().Add(-1 * time.Second)
	pe.mu.Unlock()

	if !pe.ShouldAutoExit() {
		t.Error("should auto-exit after timer expired")
	}
}
|
||||
|
||||
// SP-07: Manual deactivation of safe mode returns the engine to NONE.
func TestPreservation_SP07_ManualDeactivate(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	pe.ActivateMode(ModeSafe, "test", "auto")
	err := pe.DeactivateMode("architect:admin")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if pe.CurrentMode() != ModeNone {
		t.Errorf("expected NONE, got %s", pe.CurrentMode())
	}
}

// SP-08: Lockdown deactivation is allowed (only apoptosis is terminal).
func TestPreservation_SP08_LockdownDeactivate(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	pe.ActivateMode(ModeLockdown, "test", "auto")
	err := pe.DeactivateMode("architect:admin")
	if err != nil {
		t.Fatalf("lockdown deactivation should succeed: %v", err)
	}
}

// SP-09: History audit log records actions and ends with the deactivation.
func TestPreservation_SP09_AuditHistory(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)

	pe.ActivateMode(ModeSafe, "test", "auto")
	pe.DeactivateMode("admin")

	history := pe.History()
	if len(history) == 0 {
		t.Error("expected audit history entries")
	}

	// Last entry should be deactivation.
	last := history[len(history)-1]
	if last.Action != "deactivated" {
		t.Errorf("expected deactivated, got %s", last.Action)
	}
}
|
||||
|
||||
// SP-10: Action failure in non-apoptosis mode aborts the activation and
// leaves the engine in its previous mode.
func TestPreservation_SP10_ActionFailure(t *testing.T) {
	log := newModeActionLog()
	log.failAction = "disable_non_essential_services"
	pe := NewPreservationEngine(log.execute)

	err := pe.ActivateMode(ModeSafe, "test", "auto")
	if err == nil {
		t.Error("expected error when safe mode action fails")
	}
	// Mode should not have changed due to failure.
	if pe.CurrentMode() != ModeNone {
		t.Errorf("expected NONE after failed activation, got %s", pe.CurrentMode())
	}
}

// SP-10b: Action failure in apoptosis mode continues (best-effort shutdown).
func TestPreservation_SP10b_ApoptosisActionFailure(t *testing.T) {
	log := newModeActionLog()
	log.failAction = "graceful_shutdown"
	pe := NewPreservationEngine(log.execute)

	// Apoptosis should continue despite action failures.
	err := pe.ActivateMode(ModeApoptosis, "rootkit", "auto")
	if err != nil {
		t.Fatalf("apoptosis should not fail on action errors: %v", err)
	}
	if pe.CurrentMode() != ModeApoptosis {
		t.Errorf("expected APOPTOSIS, got %s", pe.CurrentMode())
	}
}

// ActivateMode(ModeNone) must be rejected — DeactivateMode is the exit path.
func TestPreservation_ModeNoneRejected(t *testing.T) {
	pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
	err := pe.ActivateMode(ModeNone, "test", "auto")
	if err == nil {
		t.Error("expected error activating ModeNone")
	}
}

// Deactivating while already NONE is a silent no-op.
func TestPreservation_DeactivateNone(t *testing.T) {
	pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
	err := pe.DeactivateMode("admin")
	if err != nil {
		t.Errorf("deactivating NONE should be no-op: %v", err)
	}
}

// ShouldAutoExit is always false outside safe mode.
func TestPreservation_AutoExitNotSafe(t *testing.T) {
	pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
	if pe.ShouldAutoExit() {
		t.Error("should not auto-exit when mode is NONE")
	}
}
|
||||
|
||||
// --- Integrity Verifier Tests ---
|
||||
|
||||
// SP-04 (spec): Binary integrity check — hash mismatch. A registered binary
// verifies clean, is then tampered with, and must report COMPROMISED.
func TestIntegrity_BinaryMismatch(t *testing.T) {
	tmpDir := t.TempDir()
	binPath := filepath.Join(tmpDir, "test-binary")
	os.WriteFile(binPath, []byte("original content"), 0o644) // error ignored: test fixture write

	// Calculate correct hash.
	h := sha256.Sum256([]byte("original content"))
	correctHash := hex.EncodeToString(h[:])

	iv := NewIntegrityVerifier([]byte("test-key"))
	iv.RegisterBinary(binPath, correctHash)

	// Verify (should pass).
	report := iv.VerifyAll()
	if report.Overall != IntegrityVerified {
		t.Errorf("expected VERIFIED, got %s", report.Overall)
	}

	// Tamper with the binary.
	os.WriteFile(binPath, []byte("tampered content"), 0o644)

	// Verify (should fail).
	report = iv.VerifyAll()
	if report.Overall != IntegrityCompromised {
		t.Errorf("expected COMPROMISED, got %s", report.Overall)
	}
	bs := report.Binaries[binPath]
	if bs.Status != IntegrityCompromised {
		t.Errorf("expected binary COMPROMISED, got %s", bs.Status)
	}
}

// A registered binary whose file is missing reports UNKNOWN, not COMPROMISED.
func TestIntegrity_BinaryNotFound(t *testing.T) {
	iv := NewIntegrityVerifier([]byte("test-key"))
	iv.RegisterBinary("/nonexistent/binary", "abc123")

	report := iv.VerifyAll()
	bs := report.Binaries["/nonexistent/binary"]
	if bs.Status != IntegrityUnknown {
		t.Errorf("expected UNKNOWN for missing binary, got %s", bs.Status)
	}
}
|
||||
|
||||
// Config HMAC computation: a readable config is valid and carries a
// non-empty computed HMAC.
func TestIntegrity_ConfigHMAC(t *testing.T) {
	tmpDir := t.TempDir()
	cfgPath := filepath.Join(tmpDir, "config.yaml")
	os.WriteFile(cfgPath, []byte("server:\n  port: 8080"), 0o644) // error ignored: test fixture write

	iv := NewIntegrityVerifier([]byte("hmac-key"))
	iv.RegisterConfig(cfgPath)

	report := iv.VerifyAll()
	cs := report.Configs[cfgPath]
	if !cs.Valid {
		t.Errorf("expected valid config, got error: %s", cs.Error)
	}
	if cs.CurrentHMAC == "" {
		t.Error("expected non-empty HMAC")
	}
}

// An unreadable config file must be reported as invalid.
func TestIntegrity_ConfigUnreadable(t *testing.T) {
	iv := NewIntegrityVerifier([]byte("key"))
	iv.RegisterConfig("/nonexistent/config.yaml")

	report := iv.VerifyAll()
	cs := report.Configs["/nonexistent/config.yaml"]
	if cs.Valid {
		t.Error("expected invalid for unreadable config")
	}
}
|
||||
|
||||
// Decision chain — file does not exist (OK: an absent chain is vacuously valid).
func TestIntegrity_ChainNotExist(t *testing.T) {
	iv := NewIntegrityVerifier([]byte("key"))
	iv.SetChainPath("/nonexistent/decisions.log")

	report := iv.VerifyAll()
	if report.Chain == nil {
		t.Fatal("expected chain status")
	}
	if !report.Chain.Valid {
		t.Error("nonexistent chain should be valid (no entries)")
	}
}

// Decision chain — an existing, readable file currently verifies as valid.
func TestIntegrity_ChainExists(t *testing.T) {
	tmpDir := t.TempDir()
	chainPath := filepath.Join(tmpDir, "decisions.log")
	os.WriteFile(chainPath, []byte("entry1\nentry2\n"), 0o644) // error ignored: test fixture write

	iv := NewIntegrityVerifier([]byte("key"))
	iv.SetChainPath(chainPath)

	report := iv.VerifyAll()
	if report.Chain == nil {
		t.Fatal("expected chain status")
	}
	if !report.Chain.Valid {
		t.Error("expected valid chain")
	}
}

// LastReport is nil before the first verification and non-nil afterwards.
func TestIntegrity_LastReport(t *testing.T) {
	iv := NewIntegrityVerifier([]byte("key"))
	if iv.LastReport() != nil {
		t.Error("expected nil before first verify")
	}

	iv.VerifyAll()
	if iv.LastReport() == nil {
		t.Error("expected report after verify")
	}
}
|
||||
|
||||
// Pluggable integrity check in PreservationEngine.
|
||||
func TestPreservation_IntegrityCheck(t *testing.T) {
|
||||
pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
|
||||
|
||||
// Default: no integrity fn → VERIFIED.
|
||||
report := pe.CheckIntegrity()
|
||||
if report.Overall != IntegrityVerified {
|
||||
t.Errorf("expected VERIFIED, got %s", report.Overall)
|
||||
}
|
||||
|
||||
// Set custom checker.
|
||||
pe.SetIntegrityCheck(func() IntegrityReport {
|
||||
return IntegrityReport{Overall: IntegrityCompromised, Timestamp: time.Now()}
|
||||
})
|
||||
|
||||
report = pe.CheckIntegrity()
|
||||
if report.Overall != IntegrityCompromised {
|
||||
t.Errorf("expected COMPROMISED from custom checker, got %s", report.Overall)
|
||||
}
|
||||
}
|
||||
398
internal/application/resilience/recovery_playbooks.go
Normal file
398
internal/application/resilience/recovery_playbooks.go
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PlaybookStatus tracks the lifecycle state of a playbook execution.
type PlaybookStatus string

const (
	// PlaybookPending is declared for API completeness; Execute creates
	// runs directly in RUNNING, so this value is never assigned here.
	PlaybookPending PlaybookStatus = "PENDING"
	// PlaybookRunning: steps are currently being executed.
	PlaybookRunning PlaybookStatus = "RUNNING"
	// PlaybookSucceeded: all actions completed without a fatal failure.
	PlaybookSucceeded PlaybookStatus = "SUCCEEDED"
	// PlaybookFailed is not assigned anywhere in this file — failed runs
	// are marked ROLLED_BACK after rollback actions execute.
	// NOTE(review): confirm whether a non-rollback failure state is needed.
	PlaybookFailed PlaybookStatus = "FAILED"
	// PlaybookRolledBack: an action failed and rollback actions were run.
	PlaybookRolledBack PlaybookStatus = "ROLLED_BACK"
)
|
||||
|
||||
// PlaybookStep is a single step in a recovery playbook. Steps are opaque to
// the engine: Type and Params are interpreted by the PlaybookExecutorFunc.
type PlaybookStep struct {
	ID   string `json:"id"`
	Name string `json:"name"`
	// Type selects the executor backend for this step.
	Type string `json:"type"` // shell, api, consensus, crypto, systemd, http, prometheus
	// Timeout bounds one attempt; zero means no per-step deadline.
	Timeout time.Duration `json:"timeout"`
	// Retries is the total attempt budget; values <= 0 mean one attempt.
	Retries int                    `json:"retries"`
	Params  map[string]interface{} `json:"params,omitempty"`
	// OnError controls Execute's reaction to a failed attempt budget.
	OnError string `json:"on_error"` // abort, continue, rollback
	// Condition is not evaluated anywhere in this file — presumably
	// consumed by the executor. NOTE(review): confirm.
	Condition string `json:"condition,omitempty"` // prerequisite condition
}
|
||||
|
||||
// Playbook defines a complete recovery procedure: read-only diagnosis,
// recovery actions, and compensating rollback actions.
type Playbook struct {
	ID      string `json:"id"`
	Name    string `json:"name"`
	Version string `json:"version"`
	// TriggerMetric/TriggerSeverity describe when this playbook applies;
	// they are metadata here — matching is done by the caller, not Execute.
	TriggerMetric   string `json:"trigger_metric"`
	TriggerSeverity string `json:"trigger_severity"`
	// DiagnosisChecks run first; their failures are logged but never abort.
	DiagnosisChecks []PlaybookStep `json:"diagnosis_checks"`
	// Actions run in order, honoring each step's OnError policy.
	Actions []PlaybookStep `json:"actions"`
	// RollbackActions run when an action fails with abort/rollback policy.
	RollbackActions []PlaybookStep `json:"rollback_actions"`
	// SuccessCriteria are not evaluated by Execute in this file —
	// NOTE(review): presumably checked by an external verifier; confirm.
	SuccessCriteria []string `json:"success_criteria"`
}
|
||||
|
||||
// PlaybookExecution tracks a single playbook run, including every step
// result across all three phases (diagnosis, actions, rollback).
type PlaybookExecution struct {
	ID         string         `json:"id"`
	PlaybookID string         `json:"playbook_id"`
	Component  string         `json:"component"`
	Status     PlaybookStatus `json:"status"`
	StartedAt  time.Time      `json:"started_at"`
	// NOTE(review): `omitempty` has no effect on a time.Time zero value in
	// encoding/json — a zero CompletedAt still serializes.
	CompletedAt time.Time    `json:"completed_at,omitempty"`
	StepsRun    []StepResult `json:"steps_run"`
	Error       string       `json:"error,omitempty"`
}
|
||||
|
||||
// StepResult records the outcome of a single playbook step.
type StepResult struct {
	StepID   string `json:"step_id"`
	StepName string `json:"step_name"`
	Success  bool   `json:"success"`
	// Duration spans all retry attempts of the step, not just the last one.
	Duration time.Duration `json:"duration"`
	Output   string        `json:"output,omitempty"`
	Error    string        `json:"error,omitempty"`
}
|
||||
|
||||
// PlaybookExecutorFunc runs a single playbook step for a component and
// returns the step's textual output. The engine applies the step timeout via
// ctx; implementations must respect cancellation.
type PlaybookExecutorFunc func(ctx context.Context, step PlaybookStep, component string) (string, error)
|
||||
|
||||
// RecoveryPlaybookEngine manages and executes recovery playbooks.
// All fields are guarded by mu; executions grows without bound
// (NOTE(review): no eviction is visible in this file — confirm callers cap
// history or restart periodically).
type RecoveryPlaybookEngine struct {
	mu         sync.RWMutex
	playbooks  map[string]*Playbook
	executions []*PlaybookExecution
	// execCount is a monotonically increasing counter used to mint IDs.
	execCount int64
	executor  PlaybookExecutorFunc
	logger    *slog.Logger
}
|
||||
|
||||
// NewRecoveryPlaybookEngine creates a new playbook engine.
|
||||
func NewRecoveryPlaybookEngine(executor PlaybookExecutorFunc) *RecoveryPlaybookEngine {
|
||||
return &RecoveryPlaybookEngine{
|
||||
playbooks: make(map[string]*Playbook),
|
||||
executions: make([]*PlaybookExecution, 0),
|
||||
executor: executor,
|
||||
logger: slog.Default().With("component", "sarl-recovery-playbooks"),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterPlaybook adds a playbook to the engine.
|
||||
func (rpe *RecoveryPlaybookEngine) RegisterPlaybook(pb Playbook) {
|
||||
rpe.mu.Lock()
|
||||
defer rpe.mu.Unlock()
|
||||
rpe.playbooks[pb.ID] = &pb
|
||||
rpe.logger.Info("playbook registered", "id", pb.ID, "name", pb.Name)
|
||||
}
|
||||
|
||||
// Execute runs a playbook for a given component. Returns the execution ID.
|
||||
func (rpe *RecoveryPlaybookEngine) Execute(ctx context.Context, playbookID, component string) (string, error) {
|
||||
rpe.mu.Lock()
|
||||
pb, ok := rpe.playbooks[playbookID]
|
||||
if !ok {
|
||||
rpe.mu.Unlock()
|
||||
return "", fmt.Errorf("playbook %s not found", playbookID)
|
||||
}
|
||||
|
||||
rpe.execCount++
|
||||
exec := &PlaybookExecution{
|
||||
ID: fmt.Sprintf("exec-%d", rpe.execCount),
|
||||
PlaybookID: playbookID,
|
||||
Component: component,
|
||||
Status: PlaybookRunning,
|
||||
StartedAt: time.Now(),
|
||||
StepsRun: make([]StepResult, 0),
|
||||
}
|
||||
rpe.executions = append(rpe.executions, exec)
|
||||
rpe.mu.Unlock()
|
||||
|
||||
rpe.logger.Info("playbook execution started",
|
||||
"exec_id", exec.ID,
|
||||
"playbook", pb.Name,
|
||||
"component", component,
|
||||
)
|
||||
|
||||
// Phase 1: Diagnosis checks.
|
||||
for _, check := range pb.DiagnosisChecks {
|
||||
result := rpe.runStep(ctx, check, component)
|
||||
exec.StepsRun = append(exec.StepsRun, result)
|
||||
if !result.Success {
|
||||
rpe.logger.Warn("diagnosis check failed",
|
||||
"step", check.ID,
|
||||
"error", result.Error,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Execute recovery actions.
|
||||
var execErr error
|
||||
for _, action := range pb.Actions {
|
||||
result := rpe.runStep(ctx, action, component)
|
||||
exec.StepsRun = append(exec.StepsRun, result)
|
||||
|
||||
if !result.Success {
|
||||
switch action.OnError {
|
||||
case "continue":
|
||||
continue
|
||||
case "rollback":
|
||||
execErr = fmt.Errorf("step %s failed (rollback): %s", action.ID, result.Error)
|
||||
default: // "abort"
|
||||
execErr = fmt.Errorf("step %s failed: %s", action.ID, result.Error)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Handle result.
|
||||
if execErr != nil {
|
||||
rpe.logger.Error("playbook failed, executing rollback",
|
||||
"exec_id", exec.ID,
|
||||
"error", execErr,
|
||||
)
|
||||
|
||||
// Execute rollback.
|
||||
for _, rb := range pb.RollbackActions {
|
||||
result := rpe.runStep(ctx, rb, component)
|
||||
exec.StepsRun = append(exec.StepsRun, result)
|
||||
}
|
||||
|
||||
exec.Status = PlaybookRolledBack
|
||||
exec.Error = execErr.Error()
|
||||
} else {
|
||||
exec.Status = PlaybookSucceeded
|
||||
rpe.logger.Info("playbook succeeded",
|
||||
"exec_id", exec.ID,
|
||||
"component", component,
|
||||
"duration", time.Since(exec.StartedAt),
|
||||
)
|
||||
}
|
||||
|
||||
exec.CompletedAt = time.Now()
|
||||
return exec.ID, execErr
|
||||
}
|
||||
|
||||
// runStep executes a single step with timeout and retries.
|
||||
func (rpe *RecoveryPlaybookEngine) runStep(ctx context.Context, step PlaybookStep, component string) StepResult {
|
||||
start := time.Now()
|
||||
result := StepResult{
|
||||
StepID: step.ID,
|
||||
StepName: step.Name,
|
||||
}
|
||||
|
||||
retries := step.Retries
|
||||
if retries <= 0 {
|
||||
retries = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < retries; attempt++ {
|
||||
stepCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if step.Timeout > 0 {
|
||||
stepCtx, cancel = context.WithTimeout(ctx, step.Timeout)
|
||||
}
|
||||
|
||||
output, err := rpe.executor(stepCtx, step, component)
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
result.Success = true
|
||||
result.Output = output
|
||||
result.Duration = time.Since(start)
|
||||
return result
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
if attempt < retries-1 {
|
||||
rpe.logger.Warn("step retry",
|
||||
"step", step.ID,
|
||||
"attempt", attempt+1,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
result.Success = false
|
||||
result.Error = lastErr.Error()
|
||||
result.Duration = time.Since(start)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetExecution returns a playbook execution by ID.
|
||||
// Returns a deep copy to prevent data races with the execution goroutine.
|
||||
func (rpe *RecoveryPlaybookEngine) GetExecution(id string) (*PlaybookExecution, bool) {
|
||||
rpe.mu.RLock()
|
||||
defer rpe.mu.RUnlock()
|
||||
|
||||
for _, exec := range rpe.executions {
|
||||
if exec.ID == id {
|
||||
cp := *exec
|
||||
cp.StepsRun = make([]StepResult, len(exec.StepsRun))
|
||||
copy(cp.StepsRun, exec.StepsRun)
|
||||
return &cp, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RecentExecutions returns the last N executions.
|
||||
// Returns deep copies to prevent data races with the execution goroutine.
|
||||
func (rpe *RecoveryPlaybookEngine) RecentExecutions(n int) []PlaybookExecution {
|
||||
rpe.mu.RLock()
|
||||
defer rpe.mu.RUnlock()
|
||||
|
||||
total := len(rpe.executions)
|
||||
if total == 0 {
|
||||
return nil
|
||||
}
|
||||
start := total - n
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
result := make([]PlaybookExecution, 0, n)
|
||||
for i := start; i < total; i++ {
|
||||
cp := *rpe.executions[i]
|
||||
cp.StepsRun = make([]StepResult, len(rpe.executions[i].StepsRun))
|
||||
copy(cp.StepsRun, rpe.executions[i].StepsRun)
|
||||
result = append(result, cp)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// PlaybookCount returns the number of registered playbooks.
|
||||
func (rpe *RecoveryPlaybookEngine) PlaybookCount() int {
|
||||
rpe.mu.RLock()
|
||||
defer rpe.mu.RUnlock()
|
||||
return len(rpe.playbooks)
|
||||
}
|
||||
|
||||
// --- Built-in playbooks per ТЗ §7.1 ---
|
||||
|
||||
// DefaultPlaybooks returns the 3 built-in recovery playbooks.
|
||||
func DefaultPlaybooks() []Playbook {
|
||||
return []Playbook{
|
||||
ComponentResurrectionPlaybook(),
|
||||
ConsensusRecoveryPlaybook(),
|
||||
CryptoRotationPlaybook(),
|
||||
}
|
||||
}
|
||||
|
||||
// ComponentResurrectionPlaybook per ТЗ §7.1.1.
//
// Recovers an offline component: probe the failure, capture forensics,
// restart via systemd, then verify recovery. Side tasks (forensics, cleanup,
// notification) use OnError "continue" so they never block the restart.
func ComponentResurrectionPlaybook() Playbook {
	return Playbook{
		ID:              "component-resurrection",
		Name:            "Component Resurrection",
		Version:         "1.0",
		TriggerMetric:   "component_offline",
		TriggerSeverity: "CRITICAL",
		// Read-only probes executed before any mutating action.
		DiagnosisChecks: []PlaybookStep{
			{ID: "diag-process", Name: "Check process exists", Type: "shell", Timeout: 5 * time.Second},
			{ID: "diag-crashes", Name: "Check recent crashes", Type: "shell", Timeout: 5 * time.Second},
			{ID: "diag-resources", Name: "Check resource exhaustion", Type: "prometheus", Timeout: 5 * time.Second},
			{ID: "diag-deps", Name: "Check dependency health", Type: "http", Timeout: 10 * time.Second},
		},
		Actions: []PlaybookStep{
			{ID: "capture-forensics", Name: "Capture forensics", Type: "shell", Timeout: 30 * time.Second, OnError: "continue"},
			{ID: "clear-resources", Name: "Clear temp resources", Type: "shell", Timeout: 10 * time.Second, OnError: "continue"},
			// The restart and its health verification are the critical path.
			{ID: "restart-component", Name: "Restart component", Type: "systemd", Timeout: 60 * time.Second, OnError: "abort"},
			{ID: "verify-health", Name: "Verify health", Type: "http", Timeout: 30 * time.Second, Retries: 3, OnError: "abort"},
			{ID: "verify-metrics", Name: "Verify metrics", Type: "prometheus", Timeout: 30 * time.Second, OnError: "continue"},
			{ID: "notify-success", Name: "Notify SOC", Type: "api", Timeout: 5 * time.Second, OnError: "continue"},
		},
		// On failure: degrade to safe mode and page a human.
		RollbackActions: []PlaybookStep{
			{ID: "rb-safe-mode", Name: "Enter safe mode", Type: "api", Timeout: 10 * time.Second},
			{ID: "rb-notify", Name: "Notify architect", Type: "api", Timeout: 5 * time.Second},
		},
		SuccessCriteria: []string{
			"component_status == HEALTHY",
			"health_check_passed == true",
			"no_crashes_for_5min == true",
		},
	}
}
|
||||
|
||||
// ConsensusRecoveryPlaybook per ТЗ §7.1.2.
//
// Recovers from a split-brain condition: pause writes, elect a leader,
// resync state, verify consistency, then resume. A failed state sync
// triggers rollback; rollback keeps the cluster read-only and alerts.
func ConsensusRecoveryPlaybook() Playbook {
	return Playbook{
		ID:              "consensus-recovery",
		Name:            "Distributed Consensus Recovery",
		Version:         "1.0",
		TriggerMetric:   "split_brain",
		TriggerSeverity: "CRITICAL",
		DiagnosisChecks: []PlaybookStep{
			{ID: "diag-peers", Name: "Check peer connectivity", Type: "api", Timeout: 10 * time.Second},
			{ID: "diag-sync", Name: "Check sync status", Type: "api", Timeout: 10 * time.Second},
			{ID: "diag-genome", Name: "Verify genome", Type: "api", Timeout: 5 * time.Second},
		},
		Actions: []PlaybookStep{
			{ID: "pause-writes", Name: "Pause all writes", Type: "api", Timeout: 10 * time.Second, OnError: "abort"},
			{ID: "elect-leader", Name: "Elect leader (Raft)", Type: "consensus", Timeout: 60 * time.Second, OnError: "abort"},
			// State sync is the long pole (5 min budget) and the only step
			// that explicitly requests rollback on failure.
			{ID: "sync-state", Name: "Sync state from leader", Type: "api", Timeout: 300 * time.Second, OnError: "rollback"},
			{ID: "verify-consistency", Name: "Verify consistency", Type: "api", Timeout: 60 * time.Second, OnError: "abort"},
			{ID: "resume-writes", Name: "Resume writes", Type: "api", Timeout: 10 * time.Second, OnError: "abort"},
			{ID: "notify-cluster", Name: "Notify cluster", Type: "api", Timeout: 5 * time.Second, OnError: "continue"},
		},
		RollbackActions: []PlaybookStep{
			{ID: "rb-readonly", Name: "Maintain readonly", Type: "api", Timeout: 10 * time.Second},
			{ID: "rb-notify", Name: "Notify architect", Type: "api", Timeout: 5 * time.Second},
		},
		SuccessCriteria: []string{
			"leader_elected == true",
			"state_synced == true",
			"consistency_verified == true",
			"writes_resumed == true",
		},
	}
}
|
||||
|
||||
// CryptoRotationPlaybook per ТЗ §7.1.3.
//
// Rotates cryptographic material after a suspected key compromise: generate
// new ECDSA-P256 keys, rotate mTLS certs, re-sign the decision chain, verify
// peers, then revoke the old keys. A failed cert rotation triggers rollback
// to the previous keys.
func CryptoRotationPlaybook() Playbook {
	return Playbook{
		ID:              "crypto-rotation",
		Name:            "Cryptographic Key Rotation",
		Version:         "1.0",
		TriggerMetric:   "key_compromise",
		TriggerSeverity: "HIGH",
		DiagnosisChecks: []PlaybookStep{
			{ID: "diag-key-age", Name: "Check key age", Type: "crypto", Timeout: 5 * time.Second},
			{ID: "diag-usage", Name: "Check key usage anomaly", Type: "prometheus", Timeout: 5 * time.Second},
			{ID: "diag-tpm", Name: "Check TPM health", Type: "shell", Timeout: 5 * time.Second},
		},
		Actions: []PlaybookStep{
			{ID: "gen-keys", Name: "Generate new keys", Type: "crypto", Timeout: 30 * time.Second, OnError: "abort",
				Params: map[string]interface{}{"algorithm": "ECDSA-P256"},
			},
			{ID: "rotate-certs", Name: "Rotate mTLS certs", Type: "crypto", Timeout: 120 * time.Second, OnError: "rollback"},
			// Re-signing is best-effort: an incomplete chain re-sign must
			// not abort the rotation itself.
			{ID: "resign-chain", Name: "Re-sign decision chain", Type: "crypto", Timeout: 300 * time.Second, OnError: "continue"},
			{ID: "verify-peers", Name: "Verify peer certs", Type: "api", Timeout: 60 * time.Second, OnError: "abort"},
			{ID: "revoke-old", Name: "Revoke old keys", Type: "crypto", Timeout: 30 * time.Second, OnError: "continue"},
			{ID: "notify-soc", Name: "Notify SOC", Type: "api", Timeout: 5 * time.Second, OnError: "continue"},
		},
		RollbackActions: []PlaybookStep{
			{ID: "rb-revert-keys", Name: "Revert to previous keys", Type: "crypto", Timeout: 30 * time.Second},
			{ID: "rb-notify", Name: "Notify architect", Type: "api", Timeout: 5 * time.Second},
		},
		SuccessCriteria: []string{
			"new_keys_generated == true",
			"certs_distributed == true",
			"peers_verified == true",
			"old_keys_revoked == true",
		},
	}
}
|
||||
318
internal/application/resilience/recovery_playbooks_test.go
Normal file
318
internal/application/resilience/recovery_playbooks_test.go
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Mock playbook executor ---
|
||||
|
||||
// mockPlaybookExecutor is a PlaybookExecutorFunc test double: step IDs listed
// in failSteps return an error, everything else succeeds. callCount tallies
// every invocation across all steps.
// NOTE(review): not safe for concurrent Execute calls — fine for these
// single-goroutine tests.
type mockPlaybookExecutor struct {
	failSteps map[string]bool
	callCount int
}
||||
|
||||
func newMockPlaybookExecutor() *mockPlaybookExecutor {
|
||||
return &mockPlaybookExecutor{failSteps: make(map[string]bool)}
|
||||
}
|
||||
|
||||
func (m *mockPlaybookExecutor) execute(_ context.Context, step PlaybookStep, _ string) (string, error) {
|
||||
m.callCount++
|
||||
if m.failSteps[step.ID] {
|
||||
return "", fmt.Errorf("step %s failed", step.ID)
|
||||
}
|
||||
return fmt.Sprintf("step %s completed", step.ID), nil
|
||||
}
|
||||
|
||||
// --- Recovery Playbook Tests ---
|
||||
|
||||
// AR-01: Component resurrection (success).
|
||||
func TestPlaybook_AR01_ResurrectionSuccess(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
|
||||
|
||||
execID, err := rpe.Execute(context.Background(), "component-resurrection", "soc-ingest")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
exec, ok := rpe.GetExecution(execID)
|
||||
if !ok {
|
||||
t.Fatal("execution not found")
|
||||
}
|
||||
if exec.Status != PlaybookSucceeded {
|
||||
t.Errorf("expected SUCCEEDED, got %s", exec.Status)
|
||||
}
|
||||
if len(exec.StepsRun) == 0 {
|
||||
t.Error("expected steps to be recorded")
|
||||
}
|
||||
}
|
||||
|
||||
// AR-02: Component resurrection (failure → rollback).
|
||||
func TestPlaybook_AR02_ResurrectionFailure(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
mock.failSteps["restart-component"] = true
|
||||
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
|
||||
|
||||
_, err := rpe.Execute(context.Background(), "component-resurrection", "soc-ingest")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
execs := rpe.RecentExecutions(10)
|
||||
if len(execs) == 0 {
|
||||
t.Fatal("expected execution")
|
||||
}
|
||||
if execs[0].Status != PlaybookRolledBack {
|
||||
t.Errorf("expected ROLLED_BACK, got %s", execs[0].Status)
|
||||
}
|
||||
}
|
||||
|
||||
// AR-03: Consensus recovery (success).
|
||||
func TestPlaybook_AR03_ConsensusSuccess(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ConsensusRecoveryPlaybook())
|
||||
|
||||
_, err := rpe.Execute(context.Background(), "consensus-recovery", "cluster")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AR-04: Consensus recovery (failure → readonly maintained).
|
||||
func TestPlaybook_AR04_ConsensusFailure(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
mock.failSteps["elect-leader"] = true
|
||||
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ConsensusRecoveryPlaybook())
|
||||
|
||||
_, err := rpe.Execute(context.Background(), "consensus-recovery", "cluster")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
execs := rpe.RecentExecutions(10)
|
||||
if execs[0].Status != PlaybookRolledBack {
|
||||
t.Errorf("expected ROLLED_BACK, got %s", execs[0].Status)
|
||||
}
|
||||
}
|
||||
|
||||
// AR-05: Crypto key rotation (success).
|
||||
func TestPlaybook_AR05_CryptoSuccess(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(CryptoRotationPlaybook())
|
||||
|
||||
_, err := rpe.Execute(context.Background(), "crypto-rotation", "system")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AR-06: Crypto rotation (emergency — cert rotation fails → rollback).
|
||||
func TestPlaybook_AR06_CryptoRollback(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
mock.failSteps["rotate-certs"] = true
|
||||
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(CryptoRotationPlaybook())
|
||||
|
||||
_, err := rpe.Execute(context.Background(), "crypto-rotation", "system")
|
||||
if err == nil {
|
||||
t.Fatal("expected error on cert rotation failure")
|
||||
}
|
||||
|
||||
execs := rpe.RecentExecutions(10)
|
||||
// Should have run rollback (revert keys).
|
||||
found := false
|
||||
for _, s := range execs[0].StepsRun {
|
||||
if s.StepID == "rb-revert-keys" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected rollback step rb-revert-keys")
|
||||
}
|
||||
}
|
||||
|
||||
// AR-07: Forensic capture (all steps recorded).
|
||||
func TestPlaybook_AR07_ForensicCapture(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
|
||||
|
||||
execID, _ := rpe.Execute(context.Background(), "component-resurrection", "comp")
|
||||
exec, _ := rpe.GetExecution(execID)
|
||||
|
||||
for _, step := range exec.StepsRun {
|
||||
if step.StepID == "" {
|
||||
t.Error("step missing ID")
|
||||
}
|
||||
if step.StepName == "" {
|
||||
t.Errorf("step %s has empty name", step.StepID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AR-08: Rollback execution on action failure.
|
||||
func TestPlaybook_AR08_RollbackExecution(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
mock.failSteps["sync-state"] = true // Sync fails → rollback trigger.
|
||||
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ConsensusRecoveryPlaybook())
|
||||
|
||||
rpe.Execute(context.Background(), "consensus-recovery", "cluster")
|
||||
|
||||
execs := rpe.RecentExecutions(10)
|
||||
if execs[0].Status != PlaybookRolledBack {
|
||||
t.Errorf("expected ROLLED_BACK, got %s", execs[0].Status)
|
||||
}
|
||||
}
|
||||
|
||||
// AR-09: Step retries.
|
||||
func TestPlaybook_AR09_StepRetries(t *testing.T) {
|
||||
callCount := 0
|
||||
executor := func(_ context.Context, step PlaybookStep, _ string) (string, error) {
|
||||
callCount++
|
||||
if step.ID == "verify-health" && callCount <= 2 {
|
||||
return "", fmt.Errorf("not healthy yet")
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
rpe := NewRecoveryPlaybookEngine(executor)
|
||||
rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
|
||||
|
||||
_, err := rpe.Execute(context.Background(), "component-resurrection", "comp")
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after retries: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AR-10: Playbook not found.
|
||||
func TestPlaybook_AR10_NotFound(t *testing.T) {
|
||||
rpe := NewRecoveryPlaybookEngine(nil)
|
||||
_, err := rpe.Execute(context.Background(), "nonexistent", "comp")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent playbook")
|
||||
}
|
||||
}
|
||||
|
||||
// AR-11: Audit logging (all step timestamps).
|
||||
func TestPlaybook_AR11_AuditTimestamps(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
|
||||
|
||||
execID, _ := rpe.Execute(context.Background(), "component-resurrection", "comp")
|
||||
exec, _ := rpe.GetExecution(execID)
|
||||
|
||||
if exec.StartedAt.IsZero() {
|
||||
t.Error("missing started_at")
|
||||
}
|
||||
if exec.CompletedAt.IsZero() {
|
||||
t.Error("missing completed_at")
|
||||
}
|
||||
}
|
||||
|
||||
// AR-12: OnError=continue skips non-critical failures.
|
||||
func TestPlaybook_AR12_ContinueOnError(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
mock.failSteps["capture-forensics"] = true // OnError=continue.
|
||||
mock.failSteps["notify-success"] = true // OnError=continue.
|
||||
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
|
||||
|
||||
_, err := rpe.Execute(context.Background(), "component-resurrection", "comp")
|
||||
if err != nil {
|
||||
t.Fatalf("expected success despite continue-on-error steps: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AR-13: Context cancellation.
|
||||
func TestPlaybook_AR13_ContextCancel(t *testing.T) {
|
||||
executor := func(ctx context.Context, _ PlaybookStep, _ string) (string, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
return "ok", nil
|
||||
}
|
||||
}
|
||||
|
||||
rpe := NewRecoveryPlaybookEngine(executor)
|
||||
rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately.
|
||||
|
||||
_, err := rpe.Execute(ctx, "component-resurrection", "comp")
|
||||
// May or may not error depending on timing, but should not hang.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// AR-14: DefaultPlaybooks returns 3.
|
||||
func TestPlaybook_AR14_DefaultPlaybooks(t *testing.T) {
|
||||
pbs := DefaultPlaybooks()
|
||||
if len(pbs) != 3 {
|
||||
t.Errorf("expected 3 playbooks, got %d", len(pbs))
|
||||
}
|
||||
|
||||
ids := map[string]bool{}
|
||||
for _, pb := range pbs {
|
||||
if ids[pb.ID] {
|
||||
t.Errorf("duplicate playbook ID: %s", pb.ID)
|
||||
}
|
||||
ids[pb.ID] = true
|
||||
|
||||
if len(pb.Actions) == 0 {
|
||||
t.Errorf("playbook %s has no actions", pb.ID)
|
||||
}
|
||||
if len(pb.SuccessCriteria) == 0 {
|
||||
t.Errorf("playbook %s has no success criteria", pb.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AR-15: PlaybookCount and RecentExecutions.
|
||||
func TestPlaybook_AR15_CountsAndRecent(t *testing.T) {
|
||||
mock := newMockPlaybookExecutor()
|
||||
rpe := NewRecoveryPlaybookEngine(mock.execute)
|
||||
|
||||
if rpe.PlaybookCount() != 0 {
|
||||
t.Error("expected 0")
|
||||
}
|
||||
|
||||
for _, pb := range DefaultPlaybooks() {
|
||||
rpe.RegisterPlaybook(pb)
|
||||
}
|
||||
if rpe.PlaybookCount() != 3 {
|
||||
t.Errorf("expected 3, got %d", rpe.PlaybookCount())
|
||||
}
|
||||
|
||||
// Run two playbooks.
|
||||
rpe.Execute(context.Background(), "component-resurrection", "comp1")
|
||||
rpe.Execute(context.Background(), "crypto-rotation", "comp2")
|
||||
|
||||
recent := rpe.RecentExecutions(1)
|
||||
if len(recent) != 1 {
|
||||
t.Errorf("expected 1 recent, got %d", len(recent))
|
||||
}
|
||||
if recent[0].PlaybookID != "crypto-rotation" {
|
||||
t.Errorf("expected crypto-rotation, got %s", recent[0].PlaybookID)
|
||||
}
|
||||
|
||||
all := rpe.RecentExecutions(100)
|
||||
if len(all) != 2 {
|
||||
t.Errorf("expected 2 total, got %d", len(all))
|
||||
}
|
||||
}
|
||||
278
internal/application/shadow_ai/approval.go
Normal file
278
internal/application/shadow_ai/approval.go
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Tiered Approval Workflow ---
|
||||
// Implements §6 of the ТЗ: data classification → approval tier → SLA tracking.
|
||||
|
||||
// ApprovalStatus tracks the state of an approval request. Stored in
// ApprovalRequest.Status as a plain string.
type ApprovalStatus string

const (
	// ApprovalPending: awaiting a decision from the tier's approvers.
	ApprovalPending ApprovalStatus = "pending"
	// ApprovalApproved: explicitly approved via Approve.
	ApprovalApproved ApprovalStatus = "approved"
	// ApprovalDenied: explicitly denied via Deny, with a reason.
	ApprovalDenied ApprovalStatus = "denied"
	// ApprovalExpired: SLA deadline passed. NOTE(review): nothing in this
	// file assigns this status — verify SLA expiry is enforced somewhere.
	ApprovalExpired ApprovalStatus = "expired"
	// ApprovalAutoApproved: granted by the system for auto-approve tiers.
	ApprovalAutoApproved ApprovalStatus = "auto_approved"
)
|
||||
|
||||
// DefaultApprovalTiers defines the approval requirements per data
// classification: stricter data classes require more approvers and get
// longer (or no) SLA windows. An SLA of zero disables automatic expiry.
func DefaultApprovalTiers() []ApprovalTier {
	return []ApprovalTier{
		{
			// Public data flows without human review.
			Name:           "Tier 1: Public Data",
			DataClass:      DataPublic,
			ApprovalNeeded: nil, // Auto-approve
			SLA:            0,
			AutoApprove:    true,
		},
		{
			Name:           "Tier 2: Internal Data",
			DataClass:      DataInternal,
			ApprovalNeeded: []string{"manager"},
			SLA:            4 * time.Hour,
			AutoApprove:    false,
		},
		{
			// Confidential data needs both manager and SOC sign-off.
			Name:           "Tier 3: Confidential Data",
			DataClass:      DataConfidential,
			ApprovalNeeded: []string{"manager", "soc"},
			SLA:            24 * time.Hour,
			AutoApprove:    false,
		},
		{
			// Critical data: CISO decision only, never expires on its own.
			Name:           "Tier 4: Critical Data",
			DataClass:      DataCritical,
			ApprovalNeeded: []string{"ciso"},
			SLA:            0, // Manual only, no auto-expire
			AutoApprove:    false,
		},
	}
}
|
||||
|
||||
// ApprovalEngine manages the tiered approval workflow (ТЗ §6): requests are
// classified into a tier by data classification, then tracked until
// approved/denied. All fields are guarded by mu.
// NOTE(review): requests is never pruned in this file — confirm retention
// policy for resolved requests.
type ApprovalEngine struct {
	mu       sync.RWMutex
	tiers    []ApprovalTier
	requests map[string]*ApprovalRequest
	logger   *slog.Logger
}
|
||||
|
||||
// NewApprovalEngine creates an engine with default tiers.
|
||||
func NewApprovalEngine() *ApprovalEngine {
|
||||
return &ApprovalEngine{
|
||||
tiers: DefaultApprovalTiers(),
|
||||
requests: make(map[string]*ApprovalRequest),
|
||||
logger: slog.Default().With("component", "shadow-ai-approvals"),
|
||||
}
|
||||
}
|
||||
|
||||
// SubmitRequest creates a new approval request based on data classification.
|
||||
// Returns the request or auto-approves if the tier allows it.
|
||||
func (ae *ApprovalEngine) SubmitRequest(userID, docID string, dataClass DataClassification) *ApprovalRequest {
|
||||
ae.mu.Lock()
|
||||
defer ae.mu.Unlock()
|
||||
|
||||
tier := ae.findTier(dataClass)
|
||||
|
||||
req := &ApprovalRequest{
|
||||
ID: genApprovalID(),
|
||||
DocID: docID,
|
||||
UserID: userID,
|
||||
Tier: tier.Name,
|
||||
DataClass: dataClass,
|
||||
Status: string(ApprovalPending),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Set expiry based on SLA.
|
||||
if tier.SLA > 0 {
|
||||
req.ExpiresAt = req.CreatedAt.Add(tier.SLA)
|
||||
}
|
||||
|
||||
// Auto-approve for public data.
|
||||
if tier.AutoApprove {
|
||||
req.Status = string(ApprovalAutoApproved)
|
||||
req.ApprovedBy = "system"
|
||||
req.ResolvedAt = time.Now()
|
||||
ae.logger.Info("auto-approved",
|
||||
"request_id", req.ID,
|
||||
"user", userID,
|
||||
"data_class", dataClass,
|
||||
)
|
||||
} else {
|
||||
ae.logger.Info("approval required",
|
||||
"request_id", req.ID,
|
||||
"user", userID,
|
||||
"data_class", dataClass,
|
||||
"tier", tier.Name,
|
||||
"approvers", tier.ApprovalNeeded,
|
||||
)
|
||||
}
|
||||
|
||||
ae.requests[req.ID] = req
|
||||
return req
|
||||
}
|
||||
|
||||
// Approve approves a pending request.
|
||||
func (ae *ApprovalEngine) Approve(requestID, approvedBy string) error {
|
||||
ae.mu.Lock()
|
||||
defer ae.mu.Unlock()
|
||||
|
||||
req, ok := ae.requests[requestID]
|
||||
if !ok {
|
||||
return fmt.Errorf("request %s not found", requestID)
|
||||
}
|
||||
|
||||
if req.Status != string(ApprovalPending) {
|
||||
return fmt.Errorf("request %s is not pending (status: %s)", requestID, req.Status)
|
||||
}
|
||||
|
||||
req.Status = string(ApprovalApproved)
|
||||
req.ApprovedBy = approvedBy
|
||||
req.ResolvedAt = time.Now()
|
||||
|
||||
ae.logger.Info("approved",
|
||||
"request_id", requestID,
|
||||
"approved_by", approvedBy,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deny denies a pending request.
|
||||
func (ae *ApprovalEngine) Deny(requestID, deniedBy, reason string) error {
|
||||
ae.mu.Lock()
|
||||
defer ae.mu.Unlock()
|
||||
|
||||
req, ok := ae.requests[requestID]
|
||||
if !ok {
|
||||
return fmt.Errorf("request %s not found", requestID)
|
||||
}
|
||||
|
||||
if req.Status != string(ApprovalPending) {
|
||||
return fmt.Errorf("request %s is not pending (status: %s)", requestID, req.Status)
|
||||
}
|
||||
|
||||
req.Status = string(ApprovalDenied)
|
||||
req.DeniedBy = deniedBy
|
||||
req.Reason = reason
|
||||
req.ResolvedAt = time.Now()
|
||||
|
||||
ae.logger.Info("denied",
|
||||
"request_id", requestID,
|
||||
"denied_by", deniedBy,
|
||||
"reason", reason,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRequest returns an approval request by ID.
|
||||
func (ae *ApprovalEngine) GetRequest(requestID string) (*ApprovalRequest, bool) {
|
||||
ae.mu.RLock()
|
||||
defer ae.mu.RUnlock()
|
||||
req, ok := ae.requests[requestID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cp := *req
|
||||
return &cp, true
|
||||
}
|
||||
|
||||
// PendingRequests returns all pending approval requests.
|
||||
func (ae *ApprovalEngine) PendingRequests() []ApprovalRequest {
|
||||
ae.mu.RLock()
|
||||
defer ae.mu.RUnlock()
|
||||
|
||||
var result []ApprovalRequest
|
||||
for _, req := range ae.requests {
|
||||
if req.Status == string(ApprovalPending) {
|
||||
result = append(result, *req)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ExpireOverdue marks overdue pending requests as expired.
|
||||
// Returns the number of expired requests.
|
||||
func (ae *ApprovalEngine) ExpireOverdue() int {
|
||||
ae.mu.Lock()
|
||||
defer ae.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
expired := 0
|
||||
|
||||
for _, req := range ae.requests {
|
||||
if req.Status == string(ApprovalPending) && !req.ExpiresAt.IsZero() && now.After(req.ExpiresAt) {
|
||||
req.Status = string(ApprovalExpired)
|
||||
req.ResolvedAt = now
|
||||
expired++
|
||||
ae.logger.Warn("request expired",
|
||||
"request_id", req.ID,
|
||||
"user", req.UserID,
|
||||
"expired_at", req.ExpiresAt,
|
||||
)
|
||||
}
|
||||
}
|
||||
return expired
|
||||
}
|
||||
|
||||
// Stats returns approval workflow statistics.
|
||||
func (ae *ApprovalEngine) Stats() map[string]int {
|
||||
ae.mu.RLock()
|
||||
defer ae.mu.RUnlock()
|
||||
|
||||
stats := map[string]int{
|
||||
"total": len(ae.requests),
|
||||
"pending": 0,
|
||||
"approved": 0,
|
||||
"denied": 0,
|
||||
"expired": 0,
|
||||
"auto_approved": 0,
|
||||
}
|
||||
for _, req := range ae.requests {
|
||||
switch ApprovalStatus(req.Status) {
|
||||
case ApprovalPending:
|
||||
stats["pending"]++
|
||||
case ApprovalApproved:
|
||||
stats["approved"]++
|
||||
case ApprovalDenied:
|
||||
stats["denied"]++
|
||||
case ApprovalExpired:
|
||||
stats["expired"]++
|
||||
case ApprovalAutoApproved:
|
||||
stats["auto_approved"]++
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// Tiers returns the approval tier configuration.
|
||||
func (ae *ApprovalEngine) Tiers() []ApprovalTier {
|
||||
return ae.tiers
|
||||
}
|
||||
|
||||
func (ae *ApprovalEngine) findTier(dataClass DataClassification) ApprovalTier {
|
||||
for _, t := range ae.tiers {
|
||||
if t.DataClass == dataClass {
|
||||
return t
|
||||
}
|
||||
}
|
||||
// Default to most restrictive.
|
||||
return ae.tiers[len(ae.tiers)-1]
|
||||
}
|
||||
|
||||
// approvalCounter is a process-wide monotonically increasing sequence,
// guarded by approvalCounterMu, so concurrent callers never produce the
// same ID even within a single millisecond.
var approvalCounter uint64

var approvalCounterMu sync.Mutex

// genApprovalID builds a unique approval-request ID of the form
// "apr-<unix-millis>-<sequence>".
func genApprovalID() string {
	approvalCounterMu.Lock()
	approvalCounter++
	seq := approvalCounter
	approvalCounterMu.Unlock()

	return fmt.Sprintf("apr-%d-%d", time.Now().UnixMilli(), seq)
}
|
||||
116
internal/application/shadow_ai/correlation.go
Normal file
116
internal/application/shadow_ai/correlation.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
|
||||
)
|
||||
|
||||
// ShadowAICorrelationRules returns SOC correlation rules specific to Shadow AI
// detection. These integrate into the existing SOC correlation engine.
//
// Three rule shapes appear below: RequiredCategories (co-occurrence of the
// listed categories within TimeWindow), SequenceCategories (SAI-CR-007),
// and SeverityTrend/TrendCategory (SAI-CR-009). The exact matching
// semantics are defined by the SOC engine in internal/domain/soc —
// NOTE(review): confirm there before tuning thresholds.
func ShadowAICorrelationRules() []domsoc.SOCCorrelationRule {
	return []domsoc.SOCCorrelationRule{
		{
			ID:                 "SAI-CR-001",
			Name:               "Multi-Service Shadow AI",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          3,
			TimeWindow:         10 * time.Minute,
			Severity:           domsoc.SeverityHigh,
			KillChainPhase:     "Reconnaissance",
			MITREMapping:       []string{"T1595"},
			Description:        "User accessing 3+ distinct AI services within 10 minutes. Indicates active AI tool exploration or data shopping across providers.",
		},
		{
			ID:                 "SAI-CR-002",
			Name:               "Shadow AI + Data Exfiltration",
			RequiredCategories: []string{"shadow_ai_usage", "exfiltration"},
			MinEvents:          2,
			TimeWindow:         15 * time.Minute,
			Severity:           domsoc.SeverityCritical,
			KillChainPhase:     "Exfiltration",
			MITREMapping:       []string{"T1041", "T1567"},
			Description:        "Shadow AI usage followed by data exfiltration attempt. Possible corporate data leakage via unauthorized AI services.",
		},
		{
			ID:                 "SAI-CR-003",
			Name:               "Shadow AI Volume Spike",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          10,
			TimeWindow:         1 * time.Hour,
			Severity:           domsoc.SeverityHigh,
			KillChainPhase:     "Actions on Objectives",
			MITREMapping:       []string{"T1048"},
			Description:        "10+ shadow AI events from same source within 1 hour. Indicates bulk data transfer to external AI service.",
		},
		{
			ID:                 "SAI-CR-004",
			Name:               "Shadow AI After Hours",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          2,
			TimeWindow:         30 * time.Minute,
			Severity:           domsoc.SeverityMedium,
			KillChainPhase:     "Persistence",
			MITREMapping:       []string{"T1053"},
			Description:        "Shadow AI usage outside business hours (detected via timestamp clustering). May indicate automated scripts or insider threat.",
		},
		{
			ID:                 "SAI-CR-005",
			Name:               "Integration Failure Chain",
			RequiredCategories: []string{"integration_health"},
			MinEvents:          3,
			TimeWindow:         5 * time.Minute,
			Severity:           domsoc.SeverityCritical,
			KillChainPhase:     "Defense Evasion",
			MITREMapping:       []string{"T1562"},
			Description:        "3+ integration health failures in 5 minutes. Possible attack on enforcement infrastructure to blind Shadow AI detection.",
		},
		{
			ID:                 "SAI-CR-006",
			Name:               "Shadow AI + PII Leak",
			RequiredCategories: []string{"shadow_ai_usage", "pii_leak"},
			MinEvents:          2,
			TimeWindow:         10 * time.Minute,
			Severity:           domsoc.SeverityCritical,
			KillChainPhase:     "Exfiltration",
			MITREMapping:       []string{"T1567.002"},
			Description:        "Shadow AI usage combined with PII leak detection. GDPR/regulatory violation in progress — immediate response required.",
		},
		// Ordered sequence: the evasion event must follow the AI usage event.
		{
			ID:                 "SAI-CR-007",
			Name:               "Shadow AI Evasion Attempt",
			SequenceCategories: []string{"shadow_ai_usage", "evasion"},
			MinEvents:          2,
			TimeWindow:         10 * time.Minute,
			Severity:           domsoc.SeverityHigh,
			KillChainPhase:     "Defense Evasion",
			MITREMapping:       []string{"T1090", "T1573"},
			Description:        "Shadow AI usage followed by evasion technique (VPN, proxy chaining, encoding). User attempting to bypass detection.",
		},
		// CrossSensor: events must originate from distinct sensors/segments.
		{
			ID:                 "SAI-CR-008",
			Name:               "Cross-Department AI Usage",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          5,
			TimeWindow:         30 * time.Minute,
			Severity:           domsoc.SeverityMedium,
			CrossSensor:        true,
			KillChainPhase:     "Lateral Movement",
			MITREMapping:       []string{"T1021"},
			Description:        "Shadow AI events from 5+ distinct network segments/sensors within 30 minutes. Indicates coordinated policy circumvention or compromised credentials used across departments.",
		},
		// Severity trend: escalating shadow AI event severity
		{
			ID:             "SAI-CR-009",
			Name:           "Shadow AI Escalation",
			SeverityTrend:  "ascending",
			TrendCategory:  "shadow_ai_usage",
			MinEvents:      3,
			TimeWindow:     30 * time.Minute,
			Severity:       domsoc.SeverityCritical,
			KillChainPhase: "Exploitation",
			MITREMapping:   []string{"T1059"},
			Description:    "Ascending severity pattern in Shadow AI events: user escalating from casual browsing to bulk data uploads. Crescendo data theft in progress.",
		},
	}
}
|
||||
503
internal/application/shadow_ai/detection.go
Normal file
503
internal/application/shadow_ai/detection.go
Normal file
|
|
@ -0,0 +1,503 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- AI Signature Database ---

// AISignatureDB contains known AI service signatures for detection.
// All fields are guarded by mu so lookups (MatchDomain, MatchHTTPHeaders,
// ScanForAPIKeys) can run concurrently with AddService.
type AISignatureDB struct {
	mu             sync.RWMutex
	services       []AIServiceInfo  // known AI services (see defaultAIServices)
	domainPatterns []*domainPattern // one compiled, anchored regex per service domain
	apiKeyPatterns []*APIKeyPattern // regexes for AI provider API-key formats
	httpSignatures []string         // lowercase "header: value" substrings identifying AI API traffic
}

// domainPattern pairs a (possibly wildcarded) domain string with its
// compiled regex and the AI service it belongs to.
type domainPattern struct {
	original string
	regex    *regexp.Regexp
	service  string
}

// APIKeyPattern defines a regex pattern for detecting AI API keys.
type APIKeyPattern struct {
	Name string `json:"name"`
	// Pattern is excluded from JSON: compiled regexes do not serialize.
	Pattern *regexp.Regexp `json:"-"`
	// Entropy is a minimum-entropy hint for the matched token —
	// NOTE(review): nothing in this file enforces it; confirm callers do.
	Entropy float64 `json:"min_entropy"`
}
|
||||
|
||||
// NewAISignatureDB creates a signature database pre-loaded with known AI services.
|
||||
func NewAISignatureDB() *AISignatureDB {
|
||||
db := &AISignatureDB{}
|
||||
db.loadDefaults()
|
||||
return db
|
||||
}
|
||||
|
||||
// loadDefaults populates the database with known AI services and patterns.
|
||||
func (db *AISignatureDB) loadDefaults() {
|
||||
db.services = defaultAIServices()
|
||||
|
||||
// Compile domain patterns.
|
||||
for _, svc := range db.services {
|
||||
for _, d := range svc.Domains {
|
||||
pattern := domainToRegex(d)
|
||||
db.domainPatterns = append(db.domainPatterns, &domainPattern{
|
||||
original: d,
|
||||
regex: pattern,
|
||||
service: svc.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// API key patterns.
|
||||
db.apiKeyPatterns = defaultAPIKeyPatterns()
|
||||
|
||||
// HTTP header signatures.
|
||||
db.httpSignatures = []string{
|
||||
"authorization: bearer sk-", // OpenAI
|
||||
"authorization: bearer ant-", // Anthropic
|
||||
"x-api-key: sk-ant-", // Anthropic v2
|
||||
"x-goog-api-key:", // Google AI
|
||||
"authorization: bearer gsk_", // Groq
|
||||
"authorization: bearer hf_", // HuggingFace
|
||||
}
|
||||
}
|
||||
|
||||
// domainToRegex converts a wildcard domain (e.g., "*.openai.com") to a
// compiled, case-insensitive, fully anchored regex.
//
// A "*" matches one or more dot-separated subdomain labels, so
// "*.openai.com" covers both "api.openai.com" and nested hosts such as
// "api.eu.openai.com". (The previous single-label expansion missed
// nested subdomains, leaving regional/nested AI endpoints undetected.)
// The wildcard still does not match the bare apex domain; apex domains
// are listed separately in the service catalog.
func domainToRegex(domain string) *regexp.Regexp {
	escaped := regexp.QuoteMeta(domain)
	// One or more labels in place of each "*".
	escaped = strings.ReplaceAll(escaped, `\*`, `[a-zA-Z0-9\-]+(?:\.[a-zA-Z0-9\-]+)*`)
	return regexp.MustCompile("(?i)^" + escaped + "$")
}
|
||||
|
||||
// MatchDomain checks if a domain matches any known AI service.
|
||||
// Returns the service name or empty string.
|
||||
func (db *AISignatureDB) MatchDomain(domain string) string {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
domain = strings.ToLower(strings.TrimSpace(domain))
|
||||
for _, dp := range db.domainPatterns {
|
||||
if dp.regex.MatchString(domain) {
|
||||
return dp.service
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// MatchHTTPHeaders checks if HTTP headers contain known AI service signatures.
|
||||
func (db *AISignatureDB) MatchHTTPHeaders(headers map[string]string) string {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
for key, value := range headers {
|
||||
headerLine := strings.ToLower(key + ": " + value)
|
||||
for _, sig := range db.httpSignatures {
|
||||
if strings.Contains(headerLine, sig) {
|
||||
return sig
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ScanForAPIKeys scans content for AI API keys.
|
||||
// Returns the matched pattern name or empty string.
|
||||
func (db *AISignatureDB) ScanForAPIKeys(content string) string {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
for _, pattern := range db.apiKeyPatterns {
|
||||
if pattern.Pattern.MatchString(content) {
|
||||
return pattern.Name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ServiceCount returns the number of known AI services.
|
||||
func (db *AISignatureDB) ServiceCount() int {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
return len(db.services)
|
||||
}
|
||||
|
||||
// DomainPatternCount returns the number of compiled domain patterns.
|
||||
func (db *AISignatureDB) DomainPatternCount() int {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
return len(db.domainPatterns)
|
||||
}
|
||||
|
||||
// AddService adds a custom AI service to the database.
|
||||
func (db *AISignatureDB) AddService(svc AIServiceInfo) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
db.services = append(db.services, svc)
|
||||
for _, d := range svc.Domains {
|
||||
pattern := domainToRegex(d)
|
||||
db.domainPatterns = append(db.domainPatterns, &domainPattern{
|
||||
original: d,
|
||||
regex: pattern,
|
||||
service: svc.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Network Detector ---

// NetworkEvent represents a network connection event for analysis.
type NetworkEvent struct {
	User        string            `json:"user"`
	Hostname    string            `json:"hostname"`
	Destination string            `json:"destination"` // Domain or IP
	Port        int               `json:"port"`
	HTTPHeaders map[string]string `json:"http_headers,omitempty"` // optional; used for signature matching when present
	TLSJA3      string            `json:"tls_ja3,omitempty"`      // TLS client fingerprint — presumably JA3; not consumed in this file
	DataSize    int64             `json:"data_size"`              // bytes observed for this connection
	Timestamp   time.Time         `json:"timestamp"`
}

// NetworkDetector analyzes network events for AI service access.
// It is stateless apart from the shared signature database.
type NetworkDetector struct {
	signatures *AISignatureDB
	logger     *slog.Logger
}
|
||||
|
||||
// NewNetworkDetector creates a new network detector with the default signature DB.
|
||||
func NewNetworkDetector() *NetworkDetector {
|
||||
return &NetworkDetector{
|
||||
signatures: NewAISignatureDB(),
|
||||
logger: slog.Default().With("component", "shadow-ai-network"),
|
||||
}
|
||||
}
|
||||
|
||||
// NewNetworkDetectorWithDB creates a detector with a custom signature database.
|
||||
func NewNetworkDetectorWithDB(db *AISignatureDB) *NetworkDetector {
|
||||
return &NetworkDetector{
|
||||
signatures: db,
|
||||
logger: slog.Default().With("component", "shadow-ai-network"),
|
||||
}
|
||||
}
|
||||
|
||||
// Analyze checks a network event for AI service access.
|
||||
// Returns a ShadowAIEvent if detected, nil otherwise.
|
||||
func (nd *NetworkDetector) Analyze(event NetworkEvent) *ShadowAIEvent {
|
||||
// Check domain match.
|
||||
if service := nd.signatures.MatchDomain(event.Destination); service != "" {
|
||||
nd.logger.Info("AI domain detected",
|
||||
"user", event.User,
|
||||
"destination", event.Destination,
|
||||
"service", service,
|
||||
)
|
||||
return &ShadowAIEvent{
|
||||
UserID: event.User,
|
||||
Hostname: event.Hostname,
|
||||
Destination: event.Destination,
|
||||
AIService: service,
|
||||
DetectionMethod: DetectNetwork,
|
||||
Action: "detected",
|
||||
DataSize: event.DataSize,
|
||||
Timestamp: event.Timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// Check HTTP header signatures.
|
||||
if sig := nd.signatures.MatchHTTPHeaders(event.HTTPHeaders); sig != "" {
|
||||
nd.logger.Info("AI HTTP signature detected",
|
||||
"user", event.User,
|
||||
"destination", event.Destination,
|
||||
"signature", sig,
|
||||
)
|
||||
return &ShadowAIEvent{
|
||||
UserID: event.User,
|
||||
Hostname: event.Hostname,
|
||||
Destination: event.Destination,
|
||||
AIService: "unknown",
|
||||
DetectionMethod: DetectHTTP,
|
||||
Action: "detected",
|
||||
DataSize: event.DataSize,
|
||||
Timestamp: event.Timestamp,
|
||||
Metadata: map[string]string{"http_signature": sig},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignatureDB returns the underlying signature database for extension.
|
||||
func (nd *NetworkDetector) SignatureDB() *AISignatureDB {
|
||||
return nd.signatures
|
||||
}
|
||||
|
||||
// --- Behavioral Detector ---

// UserBehaviorProfile tracks a user's AI access behavior for anomaly detection.
// Within the current analysis window, RecordAccess increments
// AccessFrequency per access and accumulates bytes into DataVolumePerHour;
// the "per hour" naming assumes windows are reset hourly via
// ResetCurrent — TODO confirm the caller's schedule.
type UserBehaviorProfile struct {
	UserID            string    `json:"user_id"`
	AccessFrequency   float64   `json:"access_frequency"`     // Requests per hour
	DataVolumePerHour float64   `json:"data_volume_per_hour"` // Bytes per hour
	KnownDestinations []string  `json:"known_destinations"`
	UpdatedAt         time.Time `json:"updated_at"`
}

// BehavioralAlert is emitted when anomalous AI access is detected.
type BehavioralAlert struct {
	UserID      string  `json:"user_id"`
	AnomalyType string  `json:"anomaly_type"` // "first_ai_access", "access_spike", "new_ai_destination", "data_volume_spike" (see DetectAnomalies)
	Current     float64 `json:"current"`
	Baseline    float64 `json:"baseline"`
	ZScore      float64 `json:"z_score"`
	Destination string  `json:"destination,omitempty"`
	Severity    string  `json:"severity"`
}

// BehavioralDetector detects anomalous AI usage patterns per user.
// baselines hold known-good profiles set via SetBaseline; current
// accumulates the active window. Both maps are guarded by mu; alerts are
// delivered best-effort on the buffered alertBus (full bus drops).
type BehavioralDetector struct {
	mu        sync.RWMutex
	baselines map[string]*UserBehaviorProfile
	current   map[string]*UserBehaviorProfile
	alertBus  chan BehavioralAlert
	logger    *slog.Logger
}
|
||||
|
||||
// NewBehavioralDetector creates a behavioral detector with a buffered alert bus.
|
||||
func NewBehavioralDetector(alertBufSize int) *BehavioralDetector {
|
||||
if alertBufSize <= 0 {
|
||||
alertBufSize = 100
|
||||
}
|
||||
return &BehavioralDetector{
|
||||
baselines: make(map[string]*UserBehaviorProfile),
|
||||
current: make(map[string]*UserBehaviorProfile),
|
||||
alertBus: make(chan BehavioralAlert, alertBufSize),
|
||||
logger: slog.Default().With("component", "shadow-ai-behavioral"),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordAccess records a single AI access attempt for behavioral tracking.
|
||||
func (bd *BehavioralDetector) RecordAccess(userID, destination string, dataSize int64) {
|
||||
bd.mu.Lock()
|
||||
defer bd.mu.Unlock()
|
||||
|
||||
profile, ok := bd.current[userID]
|
||||
if !ok {
|
||||
profile = &UserBehaviorProfile{
|
||||
UserID: userID,
|
||||
}
|
||||
bd.current[userID] = profile
|
||||
}
|
||||
|
||||
profile.AccessFrequency++
|
||||
profile.DataVolumePerHour += float64(dataSize)
|
||||
profile.UpdatedAt = time.Now()
|
||||
|
||||
// Track destinations.
|
||||
found := false
|
||||
for _, d := range profile.KnownDestinations {
|
||||
if d == destination {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
profile.KnownDestinations = append(profile.KnownDestinations, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// SetBaseline sets the known baseline behavior for a user.
|
||||
func (bd *BehavioralDetector) SetBaseline(userID string, profile *UserBehaviorProfile) {
|
||||
bd.mu.Lock()
|
||||
defer bd.mu.Unlock()
|
||||
bd.baselines[userID] = profile
|
||||
}
|
||||
|
||||
// DetectAnomalies compares current behavior to baselines and emits alerts.
//
// For each user active in the current window it raises:
//   - "first_ai_access" (WARNING) when the user has no baseline at all;
//   - "access_spike" (WARNING) when access frequency deviates strongly;
//   - "new_ai_destination" (HIGH) for destinations absent from the baseline;
//   - "data_volume_spike" (CRITICAL) when data volume deviates strongly.
//
// The "z-score" is a pseudo z-score: the deviation is normalized by 30%
// of the baseline value (floored at 1) rather than a measured standard
// deviation. Alerts are both returned and pushed onto the alert bus
// best-effort (see emitAlert); the bus push happens while holding the
// read lock.
func (bd *BehavioralDetector) DetectAnomalies() []BehavioralAlert {
	bd.mu.RLock()
	defer bd.mu.RUnlock()

	var alerts []BehavioralAlert

	for userID, current := range bd.current {
		baseline, ok := bd.baselines[userID]
		if !ok {
			// No baseline — any AI access from this user is suspicious.
			if current.AccessFrequency > 0 {
				alert := BehavioralAlert{
					UserID:      userID,
					AnomalyType: "first_ai_access",
					Current:     current.AccessFrequency,
					Baseline:    0,
					Severity:    "WARNING",
				}
				alerts = append(alerts, alert)
				bd.emitAlert(alert)
			}
			continue
		}

		// Z-score for access frequency.
		if baseline.AccessFrequency > 0 {
			zscore := (current.AccessFrequency - baseline.AccessFrequency) / math.Max(baseline.AccessFrequency*0.3, 1)
			if math.Abs(zscore) > 3.0 {
				alert := BehavioralAlert{
					UserID:      userID,
					AnomalyType: "access_spike",
					Current:     current.AccessFrequency,
					Baseline:    baseline.AccessFrequency,
					ZScore:      zscore,
					Severity:    "WARNING",
				}
				alerts = append(alerts, alert)
				bd.emitAlert(alert)
			}
		}

		// Detect new AI destinations.
		for _, dest := range current.KnownDestinations {
			isNew := true
			for _, known := range baseline.KnownDestinations {
				if dest == known {
					isNew = false
					break
				}
			}
			if isNew {
				alert := BehavioralAlert{
					UserID:      userID,
					AnomalyType: "new_ai_destination",
					Destination: dest,
					Severity:    "HIGH",
				}
				alerts = append(alerts, alert)
				bd.emitAlert(alert)
			}
		}

		// Z-score for data volume.
		if baseline.DataVolumePerHour > 0 {
			zscore := (current.DataVolumePerHour - baseline.DataVolumePerHour) / math.Max(baseline.DataVolumePerHour*0.3, 1)
			if math.Abs(zscore) > 3.0 {
				alert := BehavioralAlert{
					UserID:      userID,
					AnomalyType: "data_volume_spike",
					Current:     current.DataVolumePerHour,
					Baseline:    baseline.DataVolumePerHour,
					ZScore:      zscore,
					Severity:    "CRITICAL",
				}
				alerts = append(alerts, alert)
				bd.emitAlert(alert)
			}
		}
	}

	return alerts
}
|
||||
|
||||
// Alerts returns the alert channel for consuming behavioral alerts.
|
||||
func (bd *BehavioralDetector) Alerts() <-chan BehavioralAlert {
|
||||
return bd.alertBus
|
||||
}
|
||||
|
||||
// ResetCurrent clears the current period data (call after each analysis window).
|
||||
func (bd *BehavioralDetector) ResetCurrent() {
|
||||
bd.mu.Lock()
|
||||
defer bd.mu.Unlock()
|
||||
bd.current = make(map[string]*UserBehaviorProfile)
|
||||
}
|
||||
|
||||
func (bd *BehavioralDetector) emitAlert(alert BehavioralAlert) {
|
||||
select {
|
||||
case bd.alertBus <- alert:
|
||||
default:
|
||||
bd.logger.Warn("behavioral alert bus full, dropping alert",
|
||||
"user", alert.UserID,
|
||||
"type", alert.AnomalyType,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Default Data ---

// defaultAIServices returns the built-in catalog of known AI services.
// Domains may contain "*" wildcards (expanded by domainToRegex).
// Categories used below: llm, code_assist, image_gen, video_gen,
// audio_gen, productivity.
func defaultAIServices() []AIServiceInfo {
	return []AIServiceInfo{
		{Name: "ChatGPT", Vendor: "OpenAI", Domains: []string{"chat.openai.com", "api.openai.com", "*.openai.com"}, Category: "llm"},
		{Name: "Claude", Vendor: "Anthropic", Domains: []string{"claude.ai", "api.anthropic.com", "*.anthropic.com"}, Category: "llm"},
		{Name: "Gemini", Vendor: "Google", Domains: []string{"gemini.google.com", "generativelanguage.googleapis.com", "aistudio.google.com"}, Category: "llm"},
		{Name: "Copilot", Vendor: "Microsoft", Domains: []string{"copilot.microsoft.com", "*.copilot.microsoft.com"}, Category: "code_assist"},
		{Name: "Cohere", Vendor: "Cohere", Domains: []string{"api.cohere.ai", "dashboard.cohere.com", "*.cohere.ai"}, Category: "llm"},
		{Name: "AI21", Vendor: "AI21 Labs", Domains: []string{"api.ai21.com", "studio.ai21.com", "*.ai21.com"}, Category: "llm"},
		{Name: "HuggingFace", Vendor: "Hugging Face", Domains: []string{"api-inference.huggingface.co", "huggingface.co", "*.huggingface.co"}, Category: "llm"},
		{Name: "Replicate", Vendor: "Replicate", Domains: []string{"api.replicate.com", "replicate.com", "*.replicate.com"}, Category: "llm"},
		{Name: "Mistral", Vendor: "Mistral AI", Domains: []string{"api.mistral.ai", "chat.mistral.ai", "*.mistral.ai"}, Category: "llm"},
		{Name: "Perplexity", Vendor: "Perplexity", Domains: []string{"api.perplexity.ai", "perplexity.ai", "*.perplexity.ai"}, Category: "llm"},
		{Name: "Groq", Vendor: "Groq", Domains: []string{"api.groq.com", "groq.com", "*.groq.com"}, Category: "llm"},
		{Name: "Together", Vendor: "Together AI", Domains: []string{"api.together.xyz", "together.ai", "*.together.ai"}, Category: "llm"},
		{Name: "Stability", Vendor: "Stability AI", Domains: []string{"api.stability.ai", "*.stability.ai"}, Category: "image_gen"},
		{Name: "Midjourney", Vendor: "Midjourney", Domains: []string{"midjourney.com", "*.midjourney.com"}, Category: "image_gen"},
		{Name: "DALL-E", Vendor: "OpenAI", Domains: []string{"labs.openai.com"}, Category: "image_gen"},
		{Name: "Cursor", Vendor: "Cursor", Domains: []string{"api2.cursor.sh", "*.cursor.sh"}, Category: "code_assist"},
		{Name: "Replit AI", Vendor: "Replit", Domains: []string{"replit.com", "*.replit.com"}, Category: "code_assist"},
		{Name: "Codeium", Vendor: "Codeium", Domains: []string{"*.codeium.com", "codeium.com"}, Category: "code_assist"},
		{Name: "Tabnine", Vendor: "Tabnine", Domains: []string{"*.tabnine.com", "tabnine.com"}, Category: "code_assist"},
		{Name: "Qwen", Vendor: "Alibaba", Domains: []string{"dashscope.aliyuncs.com", "*.dashscope.aliyuncs.com"}, Category: "llm"},
		{Name: "DeepSeek", Vendor: "DeepSeek", Domains: []string{"api.deepseek.com", "chat.deepseek.com", "*.deepseek.com"}, Category: "llm"},
		{Name: "Kimi", Vendor: "Moonshot AI", Domains: []string{"api.moonshot.cn", "kimi.moonshot.cn", "*.moonshot.cn"}, Category: "llm"},
		{Name: "Baidu ERNIE", Vendor: "Baidu", Domains: []string{"aip.baidubce.com", "erniebot.baidu.com"}, Category: "llm"},
		{Name: "Jasper", Vendor: "Jasper", Domains: []string{"app.jasper.ai", "api.jasper.ai", "*.jasper.ai"}, Category: "llm"},
		{Name: "Writer", Vendor: "Writer", Domains: []string{"writer.com", "api.writer.com", "*.writer.com"}, Category: "llm"},
		{Name: "Notion AI", Vendor: "Notion", Domains: []string{"www.notion.so"}, Category: "productivity"},
		{Name: "Grammarly AI", Vendor: "Grammarly", Domains: []string{"*.grammarly.com"}, Category: "productivity"},
		{Name: "Runway", Vendor: "Runway", Domains: []string{"app.runwayml.com", "api.runwayml.com", "*.runwayml.com"}, Category: "video_gen"},
		{Name: "Pika", Vendor: "Pika", Domains: []string{"pika.art", "*.pika.art"}, Category: "video_gen"},
		{Name: "ElevenLabs", Vendor: "ElevenLabs", Domains: []string{"api.elevenlabs.io", "elevenlabs.io", "*.elevenlabs.io"}, Category: "audio_gen"},
		{Name: "Suno", Vendor: "Suno", Domains: []string{"suno.com", "*.suno.com"}, Category: "audio_gen"},
		{Name: "OpenRouter", Vendor: "OpenRouter", Domains: []string{"openrouter.ai", "*.openrouter.ai"}, Category: "llm"},
		{Name: "Scale AI", Vendor: "Scale", Domains: []string{"scale.com", "api.scale.com", "*.scale.com"}, Category: "llm"},
		{Name: "Inflection Pi", Vendor: "Inflection", Domains: []string{"pi.ai", "api.inflection.ai"}, Category: "llm"},
		{Name: "Grok", Vendor: "xAI", Domains: []string{"grok.x.ai", "api.x.ai"}, Category: "llm"},
		{Name: "Character.AI", Vendor: "Character.AI", Domains: []string{"character.ai", "*.character.ai"}, Category: "llm"},
		{Name: "Poe", Vendor: "Quora", Domains: []string{"poe.com", "*.poe.com"}, Category: "llm"},
		{Name: "You.com", Vendor: "You.com", Domains: []string{"you.com", "api.you.com"}, Category: "llm"},
		{Name: "Phind", Vendor: "Phind", Domains: []string{"phind.com", "*.phind.com"}, Category: "llm"},
	}
}
|
||||
|
||||
// defaultAPIKeyPatterns returns the built-in regexes for AI provider
// API-key formats. The Cohere pattern is a generic UUID-like shape and
// is the loosest of the set — expect more false positives from it.
func defaultAPIKeyPatterns() []*APIKeyPattern {
	return []*APIKeyPattern{
		{Name: "OpenAI API Key", Pattern: regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}T3BlbkFJ[a-zA-Z0-9]{20,}`), Entropy: 4.5},
		{Name: "OpenAI Project Key", Pattern: regexp.MustCompile(`sk-proj-[a-zA-Z0-9\-_]{48,}`), Entropy: 4.5},
		{Name: "Anthropic API Key", Pattern: regexp.MustCompile(`sk-ant-[a-zA-Z0-9\-_]{90,}`), Entropy: 4.5},
		{Name: "Google AI API Key", Pattern: regexp.MustCompile(`AIza[0-9A-Za-z\-_]{35}`), Entropy: 4.0},
		{Name: "HuggingFace Token", Pattern: regexp.MustCompile(`hf_[a-zA-Z0-9]{34}`), Entropy: 4.5},
		{Name: "Groq API Key", Pattern: regexp.MustCompile(`gsk_[a-zA-Z0-9]{52}`), Entropy: 4.5},
		{Name: "Cohere API Key", Pattern: regexp.MustCompile(`[a-zA-Z0-9]{10,}-[a-zA-Z0-9]{4,}-[a-zA-Z0-9]{4,}-[a-zA-Z0-9]{4,}-[a-zA-Z0-9]{12,}`), Entropy: 4.5},
		{Name: "Replicate API Token", Pattern: regexp.MustCompile(`r8_[a-zA-Z0-9]{37}`), Entropy: 4.5},
	}
}
|
||||
|
||||
// ServicesByCategory returns AI services grouped by category.
|
||||
func ServicesByCategory() map[string][]AIServiceInfo {
|
||||
services := defaultAIServices()
|
||||
result := make(map[string][]AIServiceInfo)
|
||||
for _, svc := range services {
|
||||
result[svc.Category] = append(result[svc.Category], svc)
|
||||
}
|
||||
// Sort each category by name for deterministic output.
|
||||
for cat := range result {
|
||||
sort.Slice(result[cat], func(i, j int) bool {
|
||||
return result[cat][i].Name < result[cat][j].Name
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
353
internal/application/shadow_ai/doc_bridge.go
Normal file
353
internal/application/shadow_ai/doc_bridge.go
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Document Review Bridge ---
|
||||
// Controlled gateway for AI access: scans documents for secrets and PII,
|
||||
// supports content redaction, and routes through the approval workflow.
|
||||
|
||||
// DocReviewStatus tracks the lifecycle of a document review.
|
||||
type DocReviewStatus string
|
||||
|
||||
const (
|
||||
DocReviewPending DocReviewStatus = "pending"
|
||||
DocReviewScanning DocReviewStatus = "scanning"
|
||||
DocReviewClean DocReviewStatus = "clean"
|
||||
DocReviewRedacted DocReviewStatus = "redacted"
|
||||
DocReviewBlocked DocReviewStatus = "blocked"
|
||||
DocReviewApproved DocReviewStatus = "approved"
|
||||
)
|
||||
|
||||
// ScanResult contains the results of scanning a document.
type ScanResult struct {
	DocumentID   string             `json:"document_id"`
	Status       DocReviewStatus    `json:"status"`
	PIIFound     []PIIMatch         `json:"pii_found,omitempty"`
	SecretsFound []SecretMatch      `json:"secrets_found,omitempty"`
	DataClass    DataClassification `json:"data_classification"`
	ContentHash  string             `json:"content_hash"` // SHA-256 hex digest of the content, for dedup
	ScannedAt    time.Time          `json:"scanned_at"`
	SizeBytes    int                `json:"size_bytes"`
}

// PIIMatch represents a detected PII pattern in content.
type PIIMatch struct {
	Type     string `json:"type"`     // "email", "phone", "ssn", "credit_card", "passport"
	Location int    `json:"location"` // Character offset
	Length   int    `json:"length"`
	Masked   string `json:"masked"` // Redacted value, e.g., "j***@example.com"
}

// SecretMatch represents a detected secret/API key in content.
type SecretMatch struct {
	Type     string `json:"type"`     // "api_key", "password", "token", "private_key"
	Location int    `json:"location"` // Character offset; zero for AI-signature hits, which carry no position
	Length   int    `json:"length"`
	Provider string `json:"provider"` // "OpenAI", "AWS", "GitHub", etc.
}
|
||||
|
||||
// DocBridge manages document scanning, redaction, and review workflow.
// The reviews map is guarded by mu; the pattern tables are built once in
// NewDocBridge and only read afterwards. NOTE(review): concurrent safety of
// AISignatureDB.ScanForAPIKeys is assumed — confirm in its implementation.
type DocBridge struct {
	mu          sync.RWMutex
	reviews     map[string]*ScanResult // keyed by ScanResult.DocumentID
	piiPatterns []*piiPattern
	secretPats  []secretPattern // Cached compiled patterns
	signatures  *AISignatureDB  // Reused across scans
	maxDocSize  int             // bytes; larger documents are blocked unscanned
}

// piiPattern pairs a PII-detection regex with its masking function.
type piiPattern struct {
	name   string // reported as PIIMatch.Type
	regex  *regexp.Regexp
	maskFn func(string) string // produces the redacted rendering of a match
}
|
||||
|
||||
// NewDocBridge creates a new Document Review Bridge.
|
||||
func NewDocBridge() *DocBridge {
|
||||
return &DocBridge{
|
||||
reviews: make(map[string]*ScanResult),
|
||||
piiPatterns: defaultPIIPatterns(),
|
||||
secretPats: secretPatterns(),
|
||||
signatures: NewAISignatureDB(),
|
||||
maxDocSize: 10 * 1024 * 1024, // 10 MB
|
||||
}
|
||||
}
|
||||
|
||||
// ScanDocument scans content for PII and secrets, classifies data, and
// returns the result. The result is also stored in the review map under
// docID, replacing any previous scan for the same ID. Oversized documents
// (> maxDocSize) are blocked outright and classified CRITICAL without
// being scanned.
// NOTE(review): userID is currently unused — presumably intended for audit
// attribution; confirm against callers before removing it.
func (db *DocBridge) ScanDocument(docID, content, userID string) *ScanResult {
	result := &ScanResult{
		DocumentID: docID,
		Status:     DocReviewScanning,
		ScannedAt:  time.Now(),
		SizeBytes:  len(content),
	}

	// Content hash for dedup.
	h := sha256.Sum256([]byte(content))
	result.ContentHash = fmt.Sprintf("%x", h[:])

	// Size check: refuse to scan oversized content at all.
	if len(content) > db.maxDocSize {
		result.Status = DocReviewBlocked
		result.DataClass = DataCritical
		db.store(result)
		return result
	}

	// Scan for PII.
	result.PIIFound = db.scanPII(content)

	// Scan for secrets (reuse cached signature DB).
	if keyType := db.signatures.ScanForAPIKeys(content); keyType != "" {
		result.SecretsFound = append(result.SecretsFound, SecretMatch{
			Type:     "api_key",
			Provider: keyType,
		})
	}

	// Scan for additional secret patterns.
	result.SecretsFound = append(result.SecretsFound, db.scanSecrets(content)...)

	// Classify data based on findings.
	result.DataClass = db.classifyData(result)

	// Set status based on findings: any secret blocks the document; PII
	// alone requires redaction; otherwise the document is clean.
	if len(result.SecretsFound) > 0 {
		result.Status = DocReviewBlocked
	} else if len(result.PIIFound) > 0 {
		result.Status = DocReviewRedacted
	} else {
		result.Status = DocReviewClean
	}

	db.store(result)
	return result
}
|
||||
|
||||
// RedactContent replaces PII and secrets in content with masked values.
|
||||
func (db *DocBridge) RedactContent(content string) string {
|
||||
for _, p := range db.piiPatterns {
|
||||
content = p.regex.ReplaceAllStringFunc(content, p.maskFn)
|
||||
}
|
||||
|
||||
// Redact common secret patterns (cached).
|
||||
for _, sp := range db.secretPats {
|
||||
content = sp.regex.ReplaceAllString(content, sp.replacement)
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// GetReview returns a scan result by document ID.
|
||||
func (db *DocBridge) GetReview(docID string) (*ScanResult, bool) {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
r, ok := db.reviews[docID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cp := *r
|
||||
return &cp, true
|
||||
}
|
||||
|
||||
// RecentReviews returns the N most recent reviews.
|
||||
func (db *DocBridge) RecentReviews(limit int) []ScanResult {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
results := make([]ScanResult, 0, len(db.reviews))
|
||||
for _, r := range db.reviews {
|
||||
results = append(results, *r)
|
||||
}
|
||||
|
||||
// Sort by time desc (simple bubble for bounded set).
|
||||
for i := 0; i < len(results); i++ {
|
||||
for j := i + 1; j < len(results); j++ {
|
||||
if results[j].ScannedAt.After(results[i].ScannedAt) {
|
||||
results[i], results[j] = results[j], results[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) > limit {
|
||||
results = results[:limit]
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// Stats returns aggregate document review statistics.
|
||||
func (db *DocBridge) Stats() map[string]int {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
stats := map[string]int{
|
||||
"total": len(db.reviews),
|
||||
"clean": 0,
|
||||
"redacted": 0,
|
||||
"blocked": 0,
|
||||
}
|
||||
for _, r := range db.reviews {
|
||||
switch r.Status {
|
||||
case DocReviewClean:
|
||||
stats["clean"]++
|
||||
case DocReviewRedacted:
|
||||
stats["redacted"]++
|
||||
case DocReviewBlocked:
|
||||
stats["blocked"]++
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
func (db *DocBridge) store(result *ScanResult) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
db.reviews[result.DocumentID] = result
|
||||
}
|
||||
|
||||
// scanPII runs all PII patterns against content.
|
||||
func (db *DocBridge) scanPII(content string) []PIIMatch {
|
||||
var matches []PIIMatch
|
||||
for _, p := range db.piiPatterns {
|
||||
locs := p.regex.FindAllStringIndex(content, -1)
|
||||
for _, loc := range locs {
|
||||
matched := content[loc[0]:loc[1]]
|
||||
matches = append(matches, PIIMatch{
|
||||
Type: p.name,
|
||||
Location: loc[0],
|
||||
Length: loc[1] - loc[0],
|
||||
Masked: p.maskFn(matched),
|
||||
})
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
// scanSecrets scans for common secret patterns beyond AI API keys.
|
||||
func (db *DocBridge) scanSecrets(content string) []SecretMatch {
|
||||
var matches []SecretMatch
|
||||
for _, sp := range db.secretPats {
|
||||
locs := sp.regex.FindAllStringIndex(content, -1)
|
||||
for _, loc := range locs {
|
||||
matches = append(matches, SecretMatch{
|
||||
Type: sp.secretType,
|
||||
Location: loc[0],
|
||||
Length: loc[1] - loc[0],
|
||||
Provider: sp.provider,
|
||||
})
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
// classifyData determines the data classification level based on scan results.
|
||||
func (db *DocBridge) classifyData(result *ScanResult) DataClassification {
|
||||
if len(result.SecretsFound) > 0 {
|
||||
return DataCritical
|
||||
}
|
||||
|
||||
hasSensitivePII := false
|
||||
for _, pii := range result.PIIFound {
|
||||
switch pii.Type {
|
||||
case "ssn", "credit_card", "passport":
|
||||
return DataCritical
|
||||
case "email", "phone":
|
||||
hasSensitivePII = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasSensitivePII {
|
||||
return DataConfidential
|
||||
}
|
||||
|
||||
if result.SizeBytes > 1024*1024 { // >1MB
|
||||
return DataInternal
|
||||
}
|
||||
|
||||
return DataPublic
|
||||
}
|
||||
|
||||
// --- PII Patterns ---
|
||||
|
||||
// defaultPIIPatterns returns the built-in PII detectors. Each pattern pairs
// a regex with a masking function used both for match reporting and for
// redaction. NOTE(review): the phone and passport patterns are broad and may
// produce false positives (passport matches any 1-2 uppercase letters
// followed by 6-9 digits) — tune before relying on them for enforcement.
func defaultPIIPatterns() []*piiPattern {
	return []*piiPattern{
		{
			name:  "email",
			regex: regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`),
			// Keep the first character of the local part, mask the rest;
			// the domain is left visible.
			maskFn: func(s string) string {
				parts := strings.SplitN(s, "@", 2)
				if len(parts) != 2 {
					return "***@***"
				}
				if len(parts[0]) <= 1 {
					return "*@" + parts[1]
				}
				return string(parts[0][0]) + "***@" + parts[1]
			},
		},
		{
			name:  "phone",
			regex: regexp.MustCompile(`\+?[1-9]\d{0,2}[\s\-]?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{2,4}`),
			// Keep the first two and last two characters.
			maskFn: func(s string) string {
				if len(s) < 4 {
					return "***"
				}
				return s[:2] + strings.Repeat("*", len(s)-4) + s[len(s)-2:]
			},
		},
		{
			name:  "ssn",
			regex: regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`),
			// SSNs are fully masked, never partially shown.
			maskFn: func(_ string) string {
				return "***-**-****"
			},
		},
		{
			name:  "credit_card",
			regex: regexp.MustCompile(`\b(?:\d{4}[\s\-]?){3}\d{4}\b`),
			// PCI-style masking: only the last four digits survive;
			// separators are stripped first.
			maskFn: func(s string) string {
				clean := strings.ReplaceAll(strings.ReplaceAll(s, "-", ""), " ", "")
				if len(clean) < 4 {
					return "****"
				}
				return strings.Repeat("*", len(clean)-4) + clean[len(clean)-4:]
			},
		},
		{
			name:  "passport",
			regex: regexp.MustCompile(`\b[A-Z]{1,2}\d{6,9}\b`),
			// Keep the first two characters, mask the remainder.
			maskFn: func(s string) string {
				if len(s) <= 2 {
					return "**"
				}
				return s[:2] + strings.Repeat("*", len(s)-2)
			},
		},
	}
}
|
||||
|
||||
// secretPattern pairs a compiled secret-detection regex with the redaction
// placeholder used by RedactContent.
type secretPattern struct {
	secretType  string // reported as SecretMatch.Type, e.g. "aws_key", "password"
	provider    string // reported as SecretMatch.Provider, e.g. "AWS", "GitHub"
	regex       *regexp.Regexp
	replacement string // literal text substituted during redaction
}

// secretPatterns returns the built-in non-AI secret detectors: cloud keys,
// VCS and chat tokens, PEM private-key headers, inline passwords, and
// database connection strings.
func secretPatterns() []secretPattern {
	return []secretPattern{
		{secretType: "aws_key", provider: "AWS", regex: regexp.MustCompile(`AKIA[0-9A-Z]{16}`), replacement: "[AWS_KEY_REDACTED]"},
		{secretType: "github_token", provider: "GitHub", regex: regexp.MustCompile(`ghp_[a-zA-Z0-9]{36}`), replacement: "[GITHUB_TOKEN_REDACTED]"},
		{secretType: "github_token", provider: "GitHub", regex: regexp.MustCompile(`github_pat_[a-zA-Z0-9_]{82}`), replacement: "[GITHUB_PAT_REDACTED]"},
		{secretType: "slack_token", provider: "Slack", regex: regexp.MustCompile(`xoxb-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24}`), replacement: "[SLACK_TOKEN_REDACTED]"},
		{secretType: "private_key", provider: "Generic", regex: regexp.MustCompile(`-----BEGIN (?:RSA |EC |DSA )?PRIVATE KEY-----`), replacement: "[PRIVATE_KEY_REDACTED]"},
		{secretType: "password", provider: "Generic", regex: regexp.MustCompile(`(?i)password\s*[=:]\s*['"]?[^\s'"]{8,}`), replacement: "[PASSWORD_REDACTED]"},
		{secretType: "connection_string", provider: "Database", regex: regexp.MustCompile(`(?i)(?:mysql|postgres|mongodb)://[^\s]+`), replacement: "[DB_CONN_REDACTED]"},
	}
}
|
||||
148
internal/application/shadow_ai/fallback.go
Normal file
148
internal/application/shadow_ai/fallback.go
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FallbackManager provides priority-based enforcement with graceful degradation.
// Tries enforcement points in priority order; falls back to detect_only if all are offline.
type FallbackManager struct {
	registry *PluginRegistry
	priority []PluginType // e.g., ["proxy", "firewall", "edr"]
	strategy string       // "detect_only" | "alert_only"
	logger   *slog.Logger

	// Event logging for detect-only fallback. Optional; set via SetEventLogger.
	eventLogFn func(event ShadowAIEvent)
}
|
||||
|
||||
// NewFallbackManager creates a new fallback manager with the given enforcement priority.
|
||||
func NewFallbackManager(registry *PluginRegistry, strategy string) *FallbackManager {
|
||||
if strategy == "" {
|
||||
strategy = "detect_only"
|
||||
}
|
||||
return &FallbackManager{
|
||||
registry: registry,
|
||||
priority: []PluginType{PluginTypeProxy, PluginTypeFirewall, PluginTypeEDR},
|
||||
strategy: strategy,
|
||||
logger: slog.Default().With("component", "shadow-ai-fallback"),
|
||||
}
|
||||
}
|
||||
|
||||
// SetEventLogger sets the callback for logging detection-only events.
// NOTE(review): this write is unsynchronized — set the logger before the
// manager is used concurrently, or a data race is possible.
func (fm *FallbackManager) SetEventLogger(fn func(ShadowAIEvent)) {
	fm.eventLogFn = fn
}
|
||||
|
||||
// BlockDomain attempts to block a domain using the highest-priority healthy plugin.
// Returns the vendor that enforced, or falls back to detect_only mode.
// NOTE(review): on total fallback this returns ("", nil) — callers cannot
// distinguish "enforced" from "merely detected" via err; they must test
// enforcedBy == "".
func (fm *FallbackManager) BlockDomain(ctx context.Context, domain, reason string) (enforcedBy string, err error) {
	for _, pType := range fm.priority {
		plugins := fm.registry.GetByType(pType)
		for _, plugin := range plugins {
			ne, ok := plugin.(NetworkEnforcer)
			if !ok {
				// Try WebGateway for URL-based blocking.
				if wg, ok := plugin.(WebGateway); ok {
					vendor := wg.Vendor()
					if !fm.registry.IsHealthy(vendor) {
						continue
					}
					if err := wg.BlockURL(ctx, domain, reason); err != nil {
						fm.logger.Warn("block failed on gateway", "vendor", vendor, "error", err)
						continue
					}
					return vendor, nil
				}
				continue
			}

			// Skip unhealthy enforcers; a failed call also falls through to
			// the next candidate rather than aborting the whole attempt.
			vendor := ne.Vendor()
			if !fm.registry.IsHealthy(vendor) {
				continue
			}
			if err := ne.BlockDomain(ctx, domain, reason); err != nil {
				fm.logger.Warn("block failed on enforcer", "vendor", vendor, "error", err)
				continue
			}
			return vendor, nil
		}
	}

	// All enforcement points unavailable — fallback.
	fm.logger.Warn("all enforcement points unavailable, falling to detect_only",
		"domain", domain,
		"strategy", fm.strategy,
	)
	fm.logDetectOnly(domain, reason)
	return "", nil
}
|
||||
|
||||
// BlockIP attempts to block an IP using the highest-priority healthy firewall.
// NOTE(review): unlike BlockDomain, this iterates GetNetworkEnforcers()
// directly and does not consult fm.priority — confirm whether registry
// ordering is the intended priority here. On total fallback it returns
// ("", nil), same caveat as BlockDomain.
func (fm *FallbackManager) BlockIP(ctx context.Context, ip string, duration time.Duration, reason string) (enforcedBy string, err error) {
	enforcers := fm.registry.GetNetworkEnforcers()
	for _, ne := range enforcers {
		vendor := ne.Vendor()
		if !fm.registry.IsHealthy(vendor) {
			continue
		}
		if err := ne.BlockIP(ctx, ip, duration, reason); err != nil {
			fm.logger.Warn("block IP failed", "vendor", vendor, "error", err)
			continue
		}
		return vendor, nil
	}

	fm.logger.Warn("no healthy enforcer for IP block, falling to detect_only",
		"ip", ip,
		"strategy", fm.strategy,
	)
	fm.logDetectOnly(ip, reason)
	return "", nil
}
|
||||
|
||||
// IsolateHost attempts to isolate a host using the highest-priority healthy EDR.
|
||||
func (fm *FallbackManager) IsolateHost(ctx context.Context, hostname string) (enforcedBy string, err error) {
|
||||
controllers := fm.registry.GetEndpointControllers()
|
||||
for _, ec := range controllers {
|
||||
vendor := ec.Vendor()
|
||||
if !fm.registry.IsHealthy(vendor) {
|
||||
continue
|
||||
}
|
||||
if err := ec.IsolateHost(ctx, hostname); err != nil {
|
||||
fm.logger.Warn("isolate failed", "vendor", vendor, "error", err)
|
||||
continue
|
||||
}
|
||||
return vendor, nil
|
||||
}
|
||||
|
||||
fm.logger.Warn("no healthy EDR for host isolation, falling to detect_only",
|
||||
"hostname", hostname,
|
||||
"strategy", fm.strategy,
|
||||
)
|
||||
return "", fmt.Errorf("no healthy EDR available for host isolation")
|
||||
}
|
||||
|
||||
// logDetectOnly records a detection-only event when no enforcement is possible.
|
||||
func (fm *FallbackManager) logDetectOnly(target, reason string) {
|
||||
if fm.eventLogFn != nil {
|
||||
fm.eventLogFn(ShadowAIEvent{
|
||||
Destination: target,
|
||||
DetectionMethod: DetectNetwork,
|
||||
Action: "detect_only",
|
||||
Metadata: map[string]string{
|
||||
"reason": reason,
|
||||
"fallback_strategy": fm.strategy,
|
||||
},
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Strategy returns the configured fallback strategy
// ("detect_only" or "alert_only").
func (fm *FallbackManager) Strategy() string {
	return fm.strategy
}
|
||||
163
internal/application/shadow_ai/health.go
Normal file
163
internal/application/shadow_ai/health.go
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PluginStatus represents a plugin's operational state.
type PluginStatus string

const (
	PluginStatusHealthy  PluginStatus = "healthy"  // last check succeeded
	PluginStatusDegraded PluginStatus = "degraded" // failing, but below the offline threshold
	PluginStatusOffline  PluginStatus = "offline"  // >= MaxConsecutivePluginFailures consecutive failures
)

// PluginHealth tracks the health state of a single plugin.
type PluginHealth struct {
	Vendor      string        `json:"vendor"`
	Type        PluginType    `json:"type"`
	Status      PluginStatus  `json:"status"`
	LastCheck   time.Time     `json:"last_check"`
	Consecutive int           `json:"consecutive_failures"` // consecutive failed checks; reset to 0 on success
	Latency     time.Duration `json:"latency"`              // duration of the most recent check
	LastError   string        `json:"last_error,omitempty"`
}
|
||||
|
||||
// MaxConsecutivePluginFailures before marking offline.
const MaxConsecutivePluginFailures = 3

// HealthChecker performs continuous health monitoring of all registered plugins.
type HealthChecker struct {
	mu       sync.RWMutex // NOTE(review): declared but unused by the methods visible in this file — confirm or remove
	registry *PluginRegistry
	interval time.Duration
	alertFn  func(vendor string, status PluginStatus, msg string) // optional; invoked on offline/recovery transitions
	logger   *slog.Logger
}
|
||||
|
||||
// NewHealthChecker creates a health checker that monitors plugin health.
|
||||
func NewHealthChecker(registry *PluginRegistry, interval time.Duration, alertFn func(string, PluginStatus, string)) *HealthChecker {
|
||||
if interval <= 0 {
|
||||
interval = 30 * time.Second
|
||||
}
|
||||
return &HealthChecker{
|
||||
registry: registry,
|
||||
interval: interval,
|
||||
alertFn: alertFn,
|
||||
logger: slog.Default().With("component", "shadow-ai-health"),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins continuous health monitoring. Blocks until ctx is cancelled.
|
||||
func (hc *HealthChecker) Start(ctx context.Context) {
|
||||
hc.logger.Info("health checker started", "interval", hc.interval)
|
||||
ticker := time.NewTicker(hc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
hc.logger.Info("health checker stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
hc.checkAllPlugins(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkAllPlugins runs health checks on all registered plugins, sequentially.
// With many plugins and the 10 s per-check timeout a full sweep can take a
// while relative to the ticker interval.
// NOTE(review): plugins with no pre-existing health record are skipped
// entirely — presumably registration seeds one; confirm in PluginRegistry.
func (hc *HealthChecker) checkAllPlugins(ctx context.Context) {
	vendors := hc.registry.Vendors()

	for _, vendor := range vendors {
		plugin, ok := hc.registry.Get(vendor)
		if !ok {
			continue
		}

		existing, _ := hc.registry.GetHealth(vendor)
		if existing == nil {
			continue
		}

		// Time the check so latency can be recorded on the new health record.
		start := time.Now()
		err := hc.checkPlugin(ctx, plugin)
		latency := time.Since(start)

		health := &PluginHealth{
			Vendor:    vendor,
			Type:      existing.Type,
			LastCheck: time.Now(),
			Latency:   latency,
		}

		if err != nil {
			health.Consecutive = existing.Consecutive + 1
			health.LastError = err.Error()

			if health.Consecutive >= MaxConsecutivePluginFailures {
				health.Status = PluginStatusOffline
				// Alert only on the transition into offline, not on every
				// subsequent failed check.
				if existing.Status != PluginStatusOffline {
					hc.logger.Error("plugin went OFFLINE",
						"vendor", vendor,
						"consecutive", health.Consecutive,
						"error", err,
					)
					if hc.alertFn != nil {
						hc.alertFn(vendor, PluginStatusOffline,
							fmt.Sprintf("Plugin %s offline after %d consecutive failures: %v",
								vendor, health.Consecutive, err))
					}
				}
			} else {
				health.Status = PluginStatusDegraded
				hc.logger.Warn("plugin health check failed",
					"vendor", vendor,
					"consecutive", health.Consecutive,
					"error", err,
				)
			}
		} else {
			health.Status = PluginStatusHealthy
			health.Consecutive = 0

			// Log recovery if previously degraded/offline.
			if existing.Status != PluginStatusHealthy {
				hc.logger.Info("plugin recovered", "vendor", vendor, "latency", latency)
				if hc.alertFn != nil {
					hc.alertFn(vendor, PluginStatusHealthy,
						fmt.Sprintf("Plugin %s recovered, latency %s", vendor, latency))
				}
			}
		}

		hc.registry.SetHealth(vendor, health)
	}
}
|
||||
|
||||
// checkPlugin runs the health check for a single plugin with a 10-second
// timeout. Only the three known plugin interfaces are supported; anything
// else is reported as an error rather than silently passing.
func (hc *HealthChecker) checkPlugin(ctx context.Context, plugin interface{}) error {
	checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	switch p := plugin.(type) {
	case NetworkEnforcer:
		return p.HealthCheck(checkCtx)
	case EndpointController:
		return p.HealthCheck(checkCtx)
	case WebGateway:
		return p.HealthCheck(checkCtx)
	default:
		return fmt.Errorf("plugin does not implement HealthCheck")
	}
}
|
||||
|
||||
// CheckNow runs an immediate, synchronous health check on all plugins.
// NOTE(review): the previous comment said "non-blocking", but this call
// blocks until every plugin has been checked — callers wanting async
// behavior must wrap it in a goroutine themselves.
func (hc *HealthChecker) CheckNow(ctx context.Context) {
	hc.checkAllPlugins(ctx)
}
|
||||
225
internal/application/shadow_ai/interfaces.go
Normal file
225
internal/application/shadow_ai/interfaces.go
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
// Package shadow_ai implements the Sentinel Shadow AI Control Module.
|
||||
//
|
||||
// Five levels of shadow AI management:
|
||||
//
|
||||
// L1 — Universal Integration Layer: plugin-based enforcement (firewall, EDR, proxy)
|
||||
// L2 — Detection Engine: network signatures, endpoint, API keys, behavioral
|
||||
// L3 — Document Review Bridge: controlled LLM access with PII/secret scanning
|
||||
// L4 — Approval Workflow: tiered data classification and manager/SOC approval
|
||||
// L5 — SOC Integration: dashboard, correlation rules, playbooks, compliance
|
||||
package shadow_ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Plugin Interfaces ---
|
||||
|
||||
// NetworkEnforcer is the universal interface for ALL firewalls.
// Implementations: Check Point, Cisco ASA/FMC, Palo Alto, Fortinet.
type NetworkEnforcer interface {
	// BlockIP blocks an IP address for the given duration.
	// NOTE(review): the semantics of a zero duration (indefinite vs no-op)
	// are vendor-defined — confirm per implementation.
	BlockIP(ctx context.Context, ip string, duration time.Duration, reason string) error

	// BlockDomain blocks a domain name.
	BlockDomain(ctx context.Context, domain string, reason string) error

	// UnblockIP removes an IP block.
	UnblockIP(ctx context.Context, ip string) error

	// UnblockDomain removes a domain block.
	UnblockDomain(ctx context.Context, domain string) error

	// HealthCheck verifies the firewall API is reachable.
	HealthCheck(ctx context.Context) error

	// Vendor returns the vendor identifier (e.g., "checkpoint", "cisco", "paloalto").
	Vendor() string
}
|
||||
|
||||
// EndpointController is the universal interface for ALL EDR systems.
// Implementations: CrowdStrike, SentinelOne, Microsoft Defender.
type EndpointController interface {
	// IsolateHost quarantines a host from the network.
	IsolateHost(ctx context.Context, hostname string) error

	// ReleaseHost removes host isolation.
	ReleaseHost(ctx context.Context, hostname string) error

	// KillProcess terminates a process on a remote host.
	KillProcess(ctx context.Context, hostname string, pid int) error

	// QuarantineFile moves a file to quarantine on a remote host.
	QuarantineFile(ctx context.Context, hostname string, path string) error

	// HealthCheck verifies the EDR API is reachable.
	HealthCheck(ctx context.Context) error

	// Vendor returns the vendor identifier (e.g., "crowdstrike", "sentinelone", "defender").
	Vendor() string
}
|
||||
|
||||
// WebGateway is the universal interface for ALL proxy/CASB systems.
// Implementations: Zscaler, Netskope, Squid, BlueCoat.
type WebGateway interface {
	// BlockURL adds a URL to the blocklist.
	// NOTE(review): callers in this package pass bare domains here as a
	// fallback for domain blocking — implementations should accept both.
	BlockURL(ctx context.Context, url string, reason string) error

	// UnblockURL removes a URL from the blocklist.
	UnblockURL(ctx context.Context, url string) error

	// BlockCategory blocks an entire URL category (e.g., "Artificial Intelligence").
	BlockCategory(ctx context.Context, category string) error

	// HealthCheck verifies the gateway API is reachable.
	HealthCheck(ctx context.Context) error

	// Vendor returns the vendor identifier (e.g., "zscaler", "netskope", "squid").
	Vendor() string
}
|
||||
|
||||
// Initializer is implemented by plugins that need configuration before use.
// The config map is the vendor-specific PluginConfig.Config loaded from YAML.
type Initializer interface {
	Initialize(config map[string]interface{}) error
}
|
||||
|
||||
// --- Plugin Configuration ---
|
||||
|
||||
// PluginType categorizes enforcement points.
type PluginType string

const (
	PluginTypeFirewall PluginType = "firewall"
	PluginTypeEDR      PluginType = "edr"
	PluginTypeProxy    PluginType = "proxy"
	PluginTypeDNS      PluginType = "dns"
)

// PluginConfig defines a vendor plugin configuration loaded from YAML.
type PluginConfig struct {
	Type    PluginType             `yaml:"type" json:"type"`
	Vendor  string                 `yaml:"vendor" json:"vendor"`
	Enabled bool                   `yaml:"enabled" json:"enabled"`
	Config  map[string]interface{} `yaml:"config" json:"config"` // vendor-specific settings, passed to Initializer.Initialize
}

// IntegrationConfig is the top-level Shadow AI configuration.
type IntegrationConfig struct {
	Plugins             []PluginConfig `yaml:"plugins" json:"plugins"`
	FallbackStrategy    string         `yaml:"fallback_strategy" json:"fallback_strategy"` // "detect_only" | "alert_only"
	HealthCheckInterval time.Duration  `yaml:"health_check_interval" json:"health_check_interval"` // default: 30s
}
|
||||
|
||||
// --- Domain Types ---
|
||||
|
||||
// DetectionMethod identifies how a shadow AI usage was detected.
type DetectionMethod string

const (
	DetectNetwork    DetectionMethod = "network"    // Domain/IP match
	DetectHTTP       DetectionMethod = "http"       // HTTP header signature
	DetectTLS        DetectionMethod = "tls"        // TLS/JA3 fingerprint
	DetectProcess    DetectionMethod = "process"    // AI tool process execution
	DetectAPIKey     DetectionMethod = "api_key"    // AI API key in payload
	DetectBehavioral DetectionMethod = "behavioral" // Anomalous AI access pattern
	DetectClipboard  DetectionMethod = "clipboard"  // Large clipboard → AI browser pattern
)

// DataClassification determines the approval tier required.
// Ordered least to most sensitive: PUBLIC < INTERNAL < CONFIDENTIAL < CRITICAL.
type DataClassification string

const (
	DataPublic       DataClassification = "PUBLIC"
	DataInternal     DataClassification = "INTERNAL"
	DataConfidential DataClassification = "CONFIDENTIAL"
	DataCritical     DataClassification = "CRITICAL"
)
|
||||
|
||||
// ShadowAIEvent is a detected shadow AI usage attempt.
type ShadowAIEvent struct {
	ID              string            `json:"id"`
	UserID          string            `json:"user_id"`
	Hostname        string            `json:"hostname"`
	Destination     string            `json:"destination"` // Target AI service domain/IP
	AIService       string            `json:"ai_service"`  // "chatgpt", "claude", "gemini", etc.
	DetectionMethod DetectionMethod   `json:"detection_method"`
	Action          string            `json:"action"`      // "blocked", "allowed", "pending" (also "detect_only" on fallback)
	EnforcedBy      string            `json:"enforced_by"` // Plugin vendor that enforced
	DataSize        int64             `json:"data_size"`   // Bytes sent to AI
	Timestamp       time.Time         `json:"timestamp"`
	Metadata        map[string]string `json:"metadata,omitempty"`
}

// AIServiceInfo describes a known AI service for signature matching.
type AIServiceInfo struct {
	Name     string   `json:"name"`     // "ChatGPT", "Claude", "Gemini"
	Vendor   string   `json:"vendor"`   // "OpenAI", "Anthropic", "Google"
	Domains  []string `json:"domains"`  // ["*.openai.com", "chat.openai.com"]
	Category string   `json:"category"` // "llm", "image_gen", "code_assist"
}
|
||||
|
||||
// BlockRequest is an API request to manually block a target.
type BlockRequest struct {
	TargetType string        `json:"target_type"` // "ip", "domain", "user"
	Target     string        `json:"target"`
	Duration   time.Duration `json:"duration"`
	Reason     string        `json:"reason"`
	BlockedBy  string        `json:"blocked_by"` // RBAC user who issued the block
}

// ShadowAIStats provides aggregate statistics for the dashboard.
type ShadowAIStats struct {
	TimeRange    string         `json:"time_range"` // "24h", "7d", "30d"
	Total        int            `json:"total_attempts"`
	Blocked      int            `json:"blocked"`
	Approved     int            `json:"approved"`
	Pending      int            `json:"pending"`
	ByService    map[string]int `json:"by_service"`
	ByDepartment map[string]int `json:"by_department"`
	TopViolators []Violator     `json:"top_violators"`
}

// Violator tracks a user's shadow AI violation count.
type Violator struct {
	UserID   string `json:"user_id"`
	Attempts int    `json:"attempts"`
}
|
||||
|
||||
// ApprovalTier defines the approval requirements for a data classification level.
type ApprovalTier struct {
	Name           string             `yaml:"name" json:"name"`
	DataClass      DataClassification `yaml:"data_class" json:"data_class"`
	ApprovalNeeded []string           `yaml:"approval_needed" json:"approval_needed"` // ["manager"], ["manager", "soc"], ["ciso"]
	SLA            time.Duration      `yaml:"sla" json:"sla"`                         // time allowed before the request expires
	AutoApprove    bool               `yaml:"auto_approve" json:"auto_approve"`       // skip human approval entirely
}

// ApprovalRequest tracks a pending approval for AI access.
type ApprovalRequest struct {
	ID         string             `json:"id"`
	DocID      string             `json:"doc_id"` // links back to the scanned document
	UserID     string             `json:"user_id"`
	Tier       string             `json:"tier"`
	DataClass  DataClassification `json:"data_class"`
	Status     string             `json:"status"` // "pending", "approved", "denied", "expired"
	ApprovedBy string             `json:"approved_by,omitempty"`
	DeniedBy   string             `json:"denied_by,omitempty"`
	Reason     string             `json:"reason,omitempty"`
	CreatedAt  time.Time          `json:"created_at"`
	ExpiresAt  time.Time          `json:"expires_at"`
	ResolvedAt time.Time          `json:"resolved_at,omitempty"` // zero value while still pending
}
|
||||
|
||||
// ComplianceReport is the Shadow AI compliance report for GDPR/SOC2/EU AI Act.
type ComplianceReport struct {
	GeneratedAt       time.Time `json:"generated_at"`
	Period            string    `json:"period"` // "monthly", "quarterly"
	TotalInteractions int       `json:"total_interactions"`
	BlockedAttempts   int       `json:"blocked_attempts"`
	ApprovedReviews   int       `json:"approved_reviews"`
	PIIDetected       int       `json:"pii_detected"`
	SecretsDetected   int       `json:"secrets_detected"`
	AuditComplete     bool      `json:"audit_complete"`
	Regulations       []string  `json:"regulations"` // ["GDPR", "SOC2", "EU AI Act"]
}
|
||||
212
internal/application/shadow_ai/plugins.go
Normal file
212
internal/application/shadow_ai/plugins.go
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Vendor Plugin Stubs ---
|
||||
// Reference implementations for major security vendors.
|
||||
// These stubs implement the full interface with logging but no real API calls.
|
||||
// Production deployments replace these with real vendor SDK integrations.
|
||||
|
||||
// CheckPointEnforcer is a stub implementation for Check Point firewalls.
// It implements NetworkEnforcer; all enforcement methods log and return nil
// without contacting the management API.
type CheckPointEnforcer struct {
	apiURL string // Check Point Management API base URL; required, set via Initialize
	apiKey string // optional API key, set via Initialize
	logger *slog.Logger
}
|
||||
|
||||
func NewCheckPointEnforcer() *CheckPointEnforcer {
|
||||
return &CheckPointEnforcer{
|
||||
logger: slog.Default().With("component", "shadow-ai-plugin-checkpoint"),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CheckPointEnforcer) Initialize(config map[string]interface{}) error {
|
||||
if url, ok := config["api_url"].(string); ok {
|
||||
c.apiURL = url
|
||||
}
|
||||
if key, ok := config["api_key"].(string); ok {
|
||||
c.apiKey = key
|
||||
}
|
||||
if c.apiURL == "" {
|
||||
return fmt.Errorf("checkpoint: api_url required")
|
||||
}
|
||||
c.logger.Info("initialized", "api_url", c.apiURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CheckPointEnforcer) BlockIP(_ context.Context, ip string, duration time.Duration, reason string) error {
|
||||
c.logger.Info("block IP", "ip", ip, "duration", duration, "reason", reason)
|
||||
// Stub: would call Check Point Management API POST /web_api/add-host
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CheckPointEnforcer) BlockDomain(_ context.Context, domain string, reason string) error {
|
||||
c.logger.Info("block domain", "domain", domain, "reason", reason)
|
||||
// Stub: would create application-site-category block rule
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CheckPointEnforcer) UnblockIP(_ context.Context, ip string) error {
|
||||
c.logger.Info("unblock IP", "ip", ip)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CheckPointEnforcer) UnblockDomain(_ context.Context, domain string) error {
|
||||
c.logger.Info("unblock domain", "domain", domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CheckPointEnforcer) HealthCheck(ctx context.Context) error {
|
||||
if c.apiURL == "" {
|
||||
return fmt.Errorf("not configured")
|
||||
}
|
||||
// Stub: would call GET /web_api/show-session
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CheckPointEnforcer) Vendor() string { return "checkpoint" }
|
||||
|
||||
// CrowdStrikeController is a stub implementation for CrowdStrike Falcon EDR.
|
||||
type CrowdStrikeController struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
baseURL string
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewCrowdStrikeController() *CrowdStrikeController {
|
||||
return &CrowdStrikeController{
|
||||
baseURL: "https://api.crowdstrike.com",
|
||||
logger: slog.Default().With("component", "shadow-ai-plugin-crowdstrike"),
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CrowdStrikeController) Initialize(config map[string]interface{}) error {
|
||||
if id, ok := config["client_id"].(string); ok {
|
||||
cs.clientID = id
|
||||
}
|
||||
if secret, ok := config["client_secret"].(string); ok {
|
||||
cs.clientSecret = secret
|
||||
}
|
||||
if url, ok := config["base_url"].(string); ok {
|
||||
cs.baseURL = url
|
||||
}
|
||||
if cs.clientID == "" {
|
||||
return fmt.Errorf("crowdstrike: client_id required")
|
||||
}
|
||||
cs.logger.Info("initialized", "base_url", cs.baseURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CrowdStrikeController) IsolateHost(_ context.Context, hostname string) error {
|
||||
cs.logger.Info("isolate host", "hostname", hostname)
|
||||
// Stub: would call POST /devices/entities/devices-actions/v2?action_name=contain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CrowdStrikeController) ReleaseHost(_ context.Context, hostname string) error {
|
||||
cs.logger.Info("release host", "hostname", hostname)
|
||||
// Stub: would call POST /devices/entities/devices-actions/v2?action_name=lift_containment
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CrowdStrikeController) KillProcess(_ context.Context, hostname string, pid int) error {
|
||||
cs.logger.Info("kill process", "hostname", hostname, "pid", pid)
|
||||
// Stub: would use RTR session to kill process
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CrowdStrikeController) QuarantineFile(_ context.Context, hostname, path string) error {
|
||||
cs.logger.Info("quarantine file", "hostname", hostname, "path", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CrowdStrikeController) HealthCheck(ctx context.Context) error {
|
||||
if cs.clientID == "" {
|
||||
return fmt.Errorf("not configured")
|
||||
}
|
||||
// Stub: would call GET /sensors/queries/sensors/v1?limit=1
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CrowdStrikeController) Vendor() string { return "crowdstrike" }
|
||||
|
||||
// ZscalerGateway is a stub implementation for Zscaler Internet Access.
|
||||
type ZscalerGateway struct {
|
||||
cloudName string
|
||||
apiKey string
|
||||
username string
|
||||
password string
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewZscalerGateway() *ZscalerGateway {
|
||||
return &ZscalerGateway{
|
||||
logger: slog.Default().With("component", "shadow-ai-plugin-zscaler"),
|
||||
}
|
||||
}
|
||||
|
||||
func (z *ZscalerGateway) Initialize(config map[string]interface{}) error {
|
||||
if cloud, ok := config["cloud_name"].(string); ok {
|
||||
z.cloudName = cloud
|
||||
}
|
||||
if key, ok := config["api_key"].(string); ok {
|
||||
z.apiKey = key
|
||||
}
|
||||
if user, ok := config["username"].(string); ok {
|
||||
z.username = user
|
||||
}
|
||||
if pass, ok := config["password"].(string); ok {
|
||||
z.password = pass
|
||||
}
|
||||
if z.cloudName == "" {
|
||||
return fmt.Errorf("zscaler: cloud_name required")
|
||||
}
|
||||
z.logger.Info("initialized", "cloud", z.cloudName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *ZscalerGateway) BlockURL(_ context.Context, url, reason string) error {
|
||||
z.logger.Info("block URL", "url", url, "reason", reason)
|
||||
// Stub: would call PUT /webApplicationRules to add URL to block list
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *ZscalerGateway) UnblockURL(_ context.Context, url string) error {
|
||||
z.logger.Info("unblock URL", "url", url)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *ZscalerGateway) BlockCategory(_ context.Context, category string) error {
|
||||
z.logger.Info("block category", "category", category)
|
||||
// Stub: would update URL category policy to BLOCK
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *ZscalerGateway) HealthCheck(ctx context.Context) error {
|
||||
if z.cloudName == "" {
|
||||
return fmt.Errorf("not configured")
|
||||
}
|
||||
// Stub: would call GET /status
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *ZscalerGateway) Vendor() string { return "zscaler" }
|
||||
|
||||
// RegisterDefaultPlugins registers all built-in vendor plugin factories.
|
||||
func RegisterDefaultPlugins(registry *PluginRegistry) {
|
||||
registry.RegisterFactory(PluginTypeFirewall, "checkpoint", func() interface{} {
|
||||
return NewCheckPointEnforcer()
|
||||
})
|
||||
registry.RegisterFactory(PluginTypeEDR, "crowdstrike", func() interface{} {
|
||||
return NewCrowdStrikeController()
|
||||
})
|
||||
registry.RegisterFactory(PluginTypeProxy, "zscaler", func() interface{} {
|
||||
return NewZscalerGateway()
|
||||
})
|
||||
}
|
||||
212
internal/application/shadow_ai/registry.go
Normal file
212
internal/application/shadow_ai/registry.go
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PluginFactory creates a new plugin instance. Factories are registered per
// (type, vendor) pair and invoked lazily by LoadPlugins, so constructing a
// registry never touches vendor SDKs.
type PluginFactory func() interface{}

// PluginRegistry manages vendor plugin registration, loading, and lifecycle.
// Thread-safe via sync.RWMutex. The vendor name is the primary key for
// loaded plugins; factories are keyed by the composite "type_vendor" string.
type PluginRegistry struct {
	mu        sync.RWMutex
	plugins   map[string]interface{}   // vendor → plugin instance
	factories map[string]PluginFactory // "type_vendor" → factory
	configs   map[string]*PluginConfig // vendor → config
	health    map[string]*PluginHealth // vendor → health status
	logger    *slog.Logger
}

// NewPluginRegistry creates a new plugin registry with empty plugin,
// factory, config, and health tables. The zero registry is not usable;
// always construct through this function.
func NewPluginRegistry() *PluginRegistry {
	return &PluginRegistry{
		plugins:   make(map[string]interface{}),
		factories: make(map[string]PluginFactory),
		configs:   make(map[string]*PluginConfig),
		health:    make(map[string]*PluginHealth),
		logger:    slog.Default().With("component", "shadow-ai-registry"),
	}
}
|
||||
|
||||
// RegisterFactory registers a plugin factory for a given type+vendor combination.
|
||||
// Example: RegisterFactory("firewall", "checkpoint", func() interface{} { return &CheckPointEnforcer{} })
|
||||
func (r *PluginRegistry) RegisterFactory(pluginType PluginType, vendor string, factory PluginFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
key := fmt.Sprintf("%s_%s", pluginType, vendor)
|
||||
r.factories[key] = factory
|
||||
r.logger.Info("factory registered", "type", pluginType, "vendor", vendor)
|
||||
}
|
||||
|
||||
// LoadPlugins creates and initializes plugins from configuration.
// Plugins that fail to initialize are logged but do not block other plugins.
// Disabled entries and entries with no registered factory are skipped with
// a log line. Loaded plugins start with PluginStatusHealthy; the health
// checker is expected to downgrade them later if they misbehave.
// NOTE(review): the function currently always returns nil — per-plugin
// failures are only logged, never surfaced to the caller.
func (r *PluginRegistry) LoadPlugins(config *IntegrationConfig) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	loaded := 0
	// Index loop (not range-value) so pluginCfg points into config.Plugins,
	// keeping the stored *PluginConfig stable.
	for i := range config.Plugins {
		pluginCfg := &config.Plugins[i]
		if !pluginCfg.Enabled {
			r.logger.Debug("plugin disabled, skipping", "vendor", pluginCfg.Vendor)
			continue
		}

		// Same composite key convention as RegisterFactory.
		key := fmt.Sprintf("%s_%s", pluginCfg.Type, pluginCfg.Vendor)
		factory, exists := r.factories[key]
		if !exists {
			r.logger.Warn("no factory for plugin", "key", key, "vendor", pluginCfg.Vendor)
			continue
		}

		plugin := factory()

		// Initialize if plugin supports it.
		if init, ok := plugin.(Initializer); ok {
			if err := init.Initialize(pluginCfg.Config); err != nil {
				r.logger.Error("plugin init failed", "vendor", pluginCfg.Vendor, "error", err)
				continue
			}
		}

		r.plugins[pluginCfg.Vendor] = plugin
		r.configs[pluginCfg.Vendor] = pluginCfg
		r.health[pluginCfg.Vendor] = &PluginHealth{
			Vendor: pluginCfg.Vendor,
			Type:   pluginCfg.Type,
			Status: PluginStatusHealthy,
		}
		loaded++
		r.logger.Info("plugin loaded", "vendor", pluginCfg.Vendor, "type", pluginCfg.Type)
	}

	// "total" counts all configured entries, including disabled/skipped ones.
	r.logger.Info("plugin loading complete", "loaded", loaded, "total", len(config.Plugins))
	return nil
}
|
||||
|
||||
// Get returns a plugin by vendor name.
|
||||
func (r *PluginRegistry) Get(vendor string) (interface{}, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
p, ok := r.plugins[vendor]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// GetByType returns all plugins of a given type.
|
||||
func (r *PluginRegistry) GetByType(pluginType PluginType) []interface{} {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result []interface{}
|
||||
for vendor, cfg := range r.configs {
|
||||
if cfg.Type == pluginType {
|
||||
if plugin, ok := r.plugins[vendor]; ok {
|
||||
result = append(result, plugin)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetNetworkEnforcers returns all loaded NetworkEnforcer plugins.
|
||||
func (r *PluginRegistry) GetNetworkEnforcers() []NetworkEnforcer {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result []NetworkEnforcer
|
||||
for _, plugin := range r.plugins {
|
||||
if ne, ok := plugin.(NetworkEnforcer); ok {
|
||||
result = append(result, ne)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetEndpointControllers returns all loaded EndpointController plugins.
|
||||
func (r *PluginRegistry) GetEndpointControllers() []EndpointController {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result []EndpointController
|
||||
for _, plugin := range r.plugins {
|
||||
if ec, ok := plugin.(EndpointController); ok {
|
||||
result = append(result, ec)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetWebGateways returns all loaded WebGateway plugins.
|
||||
func (r *PluginRegistry) GetWebGateways() []WebGateway {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result []WebGateway
|
||||
for _, plugin := range r.plugins {
|
||||
if wg, ok := plugin.(WebGateway); ok {
|
||||
result = append(result, wg)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IsHealthy returns true if a plugin is currently healthy.
|
||||
func (r *PluginRegistry) IsHealthy(vendor string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
h, ok := r.health[vendor]
|
||||
return ok && h.Status == PluginStatusHealthy
|
||||
}
|
||||
|
||||
// SetHealth updates the health status for a plugin.
|
||||
func (r *PluginRegistry) SetHealth(vendor string, health *PluginHealth) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.health[vendor] = health
|
||||
}
|
||||
|
||||
// GetHealth returns the health status snapshot for a plugin.
|
||||
func (r *PluginRegistry) GetHealth(vendor string) (*PluginHealth, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
h, ok := r.health[vendor]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cp := *h
|
||||
return &cp, true
|
||||
}
|
||||
|
||||
// AllHealth returns health snapshots for all plugins.
|
||||
func (r *PluginRegistry) AllHealth() []PluginHealth {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
result := make([]PluginHealth, 0, len(r.health))
|
||||
for _, h := range r.health {
|
||||
result = append(result, *h)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// PluginCount returns the number of loaded plugins.
|
||||
func (r *PluginRegistry) PluginCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.plugins)
|
||||
}
|
||||
|
||||
// Vendors returns all loaded vendor names.
|
||||
func (r *PluginRegistry) Vendors() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
result := make([]string, 0, len(r.plugins))
|
||||
for v := range r.plugins {
|
||||
result = append(result, v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
1225
internal/application/shadow_ai/shadow_ai_test.go
Normal file
1225
internal/application/shadow_ai/shadow_ai_test.go
Normal file
File diff suppressed because it is too large
Load diff
373
internal/application/shadow_ai/soc_integration.go
Normal file
373
internal/application/shadow_ai/soc_integration.go
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
package shadow_ai
|
||||
|
||||
import (
	"context"
	"fmt"
	"log/slog"
	"sort"
	"sync"
	"sync/atomic"
	"time"
)
|
||||
|
||||
// ShadowAIController is the main orchestrator that ties together
// detection, enforcement, SOC event emission, and statistics.
// All mutable state (events, socEventFn, fallback, healthChecker) is
// guarded by mu; the sub-component pointers are set at construction.
type ShadowAIController struct {
	mu            sync.RWMutex
	registry      *PluginRegistry
	fallback      *FallbackManager
	healthChecker *HealthChecker
	netDetector   *NetworkDetector
	behavioral    *BehavioralDetector
	docBridge     *DocBridge
	approval      *ApprovalEngine
	events        []ShadowAIEvent // In-memory event store (bounded by maxEvents)
	maxEvents     int
	socEventFn    func(source, severity, category, description string, meta map[string]string) // Bridge to SOC event bus
	logger        *slog.Logger
}

// NewShadowAIController creates the main Shadow AI Control orchestrator
// with the built-in vendor plugins registered, a "detect_only" fallback
// strategy, and a 10k-event in-memory buffer. Call Configure to load
// plugin configuration before processing events.
func NewShadowAIController() *ShadowAIController {
	registry := NewPluginRegistry()
	RegisterDefaultPlugins(registry)
	return &ShadowAIController{
		registry:    registry,
		fallback:    NewFallbackManager(registry, "detect_only"),
		netDetector: NewNetworkDetector(),
		behavioral:  NewBehavioralDetector(100),
		docBridge:   NewDocBridge(),
		approval:    NewApprovalEngine(),
		events:      make([]ShadowAIEvent, 0, 1000),
		maxEvents:   10000,
		logger:      slog.Default().With("component", "shadow-ai-controller"),
	}
}
|
||||
|
||||
// SetSOCEventEmitter sets the function used to emit events into the SOC pipeline.
// The emitter receives (source, severity, category, description, metadata)
// and may be replaced at runtime; access is guarded by c.mu.
func (c *ShadowAIController) SetSOCEventEmitter(fn func(source, severity, category, description string, meta map[string]string)) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.socEventFn = fn
}

// Configure loads plugin configuration and initializes the integration layer.
// It rebuilds the fallback manager with the configured strategy, wires
// fallback-generated events into the in-memory event store, and creates
// (but does not start) the health checker; call StartHealthChecker to
// begin monitoring.
func (c *ShadowAIController) Configure(config *IntegrationConfig) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if err := c.registry.LoadPlugins(config); err != nil {
		return fmt.Errorf("failed to load plugins: %w", err)
	}

	c.fallback = NewFallbackManager(c.registry, config.FallbackStrategy)
	c.fallback.SetEventLogger(func(event ShadowAIEvent) {
		c.recordEvent(event)
	})

	// Default to a 30s probe interval when the config leaves it unset.
	interval := config.HealthCheckInterval
	if interval <= 0 {
		interval = 30 * time.Second
	}
	c.healthChecker = NewHealthChecker(c.registry, interval, func(vendor string, status PluginStatus, msg string) {
		// Health transitions surface as HIGH integration_health SOC events.
		c.emitSOCEvent("HIGH", "integration_health", msg, map[string]string{
			"vendor": vendor,
			"status": string(status),
		})
	})

	return nil
}

// StartHealthChecker starts continuous plugin health monitoring in a
// background goroutine. No-op before Configure has run.
// NOTE(review): assumes HealthChecker.Start returns when ctx is cancelled —
// confirm in the HealthChecker implementation.
func (c *ShadowAIController) StartHealthChecker(ctx context.Context) {
	if c.healthChecker != nil {
		go c.healthChecker.Start(ctx)
	}
}
|
||||
|
||||
// ProcessNetworkEvent analyzes a network event and enforces policy.
// Returns nil when the detector does not classify the event as shadow AI
// traffic. On detection it records behavioral data, attempts to block the
// destination domain through the fallback chain, stores the event, and
// emits a HIGH-severity SOC event. An enforcement failure downgrades the
// action to "detected" rather than dropping the event.
func (c *ShadowAIController) ProcessNetworkEvent(ctx context.Context, event NetworkEvent) *ShadowAIEvent {
	detected := c.netDetector.Analyze(event)
	if detected == nil {
		return nil
	}

	// Record behavioral data.
	c.behavioral.RecordAccess(event.User, event.Destination, event.DataSize)

	// Attempt to block. A failure is logged and processing continues —
	// the detection itself is still recorded and emitted below.
	enforcedBy, err := c.fallback.BlockDomain(ctx, event.Destination, fmt.Sprintf("Shadow AI: %s", detected.AIService))
	if err != nil {
		c.logger.Error("enforcement failed", "destination", event.Destination, "error", err)
	}

	// enforcedBy is empty when no plugin actually applied the block.
	if enforcedBy != "" {
		detected.Action = "blocked"
		detected.EnforcedBy = enforcedBy
	} else {
		detected.Action = "detected"
	}

	detected.ID = genEventID()
	c.recordEvent(*detected)

	// Emit to SOC event bus.
	c.emitSOCEvent("HIGH", "shadow_ai_usage",
		fmt.Sprintf("Shadow AI access detected: %s → %s", event.User, detected.AIService),
		map[string]string{
			"user":        event.User,
			"hostname":    event.Hostname,
			"destination": event.Destination,
			"ai_service":  detected.AIService,
			"action":      detected.Action,
			"enforced_by": detected.EnforcedBy,
		},
	)

	return detected
}
|
||||
|
||||
// ScanContent scans text content for AI API keys.
|
||||
func (c *ShadowAIController) ScanContent(content string) string {
|
||||
return c.netDetector.SignatureDB().ScanForAPIKeys(content)
|
||||
}
|
||||
|
||||
// ManualBlock manually blocks a domain or IP.
|
||||
func (c *ShadowAIController) ManualBlock(ctx context.Context, req BlockRequest) error {
|
||||
switch req.TargetType {
|
||||
case "domain":
|
||||
_, err := c.fallback.BlockDomain(ctx, req.Target, req.Reason)
|
||||
return err
|
||||
case "ip":
|
||||
_, err := c.fallback.BlockIP(ctx, req.Target, req.Duration, req.Reason)
|
||||
return err
|
||||
case "host":
|
||||
_, err := c.fallback.IsolateHost(ctx, req.Target)
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("unsupported target type: %s", req.TargetType)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns aggregate shadow AI statistics.
|
||||
func (c *ShadowAIController) GetStats(timeRange string) ShadowAIStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
cutoff := parseCutoff(timeRange)
|
||||
stats := ShadowAIStats{
|
||||
TimeRange: timeRange,
|
||||
ByService: make(map[string]int),
|
||||
ByDepartment: make(map[string]int),
|
||||
}
|
||||
|
||||
violatorMap := make(map[string]int)
|
||||
|
||||
for _, e := range c.events {
|
||||
if e.Timestamp.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
stats.Total++
|
||||
switch e.Action {
|
||||
case "blocked":
|
||||
stats.Blocked++
|
||||
case "allowed", "approved":
|
||||
stats.Approved++
|
||||
case "pending":
|
||||
stats.Pending++
|
||||
}
|
||||
if e.AIService != "" {
|
||||
stats.ByService[e.AIService]++
|
||||
}
|
||||
if dept, ok := e.Metadata["department"]; ok {
|
||||
stats.ByDepartment[dept]++
|
||||
}
|
||||
if e.UserID != "" {
|
||||
violatorMap[e.UserID]++
|
||||
}
|
||||
}
|
||||
|
||||
// Build top violators list (sorted desc).
|
||||
for uid, count := range violatorMap {
|
||||
stats.TopViolators = append(stats.TopViolators, Violator{UserID: uid, Attempts: count})
|
||||
}
|
||||
// Sort by attempts descending, limit to 10.
|
||||
for i := 0; i < len(stats.TopViolators); i++ {
|
||||
for j := i + 1; j < len(stats.TopViolators); j++ {
|
||||
if stats.TopViolators[j].Attempts > stats.TopViolators[i].Attempts {
|
||||
stats.TopViolators[i], stats.TopViolators[j] = stats.TopViolators[j], stats.TopViolators[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(stats.TopViolators) > 10 {
|
||||
stats.TopViolators = stats.TopViolators[:10]
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetEvents returns recent shadow AI events (newest first).
|
||||
func (c *ShadowAIController) GetEvents(limit int) []ShadowAIEvent {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
total := len(c.events)
|
||||
if total == 0 {
|
||||
return nil
|
||||
}
|
||||
start := total - limit
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
// Return newest first.
|
||||
result := make([]ShadowAIEvent, 0, limit)
|
||||
for i := total - 1; i >= start; i-- {
|
||||
result = append(result, c.events[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetEvent returns a single event by ID.
|
||||
func (c *ShadowAIController) GetEvent(id string) (*ShadowAIEvent, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
for i := len(c.events) - 1; i >= 0; i-- {
|
||||
if c.events[i].ID == id {
|
||||
cp := c.events[i]
|
||||
return &cp, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// IntegrationHealth returns health status snapshots of all plugins.
func (c *ShadowAIController) IntegrationHealth() []PluginHealth {
	return c.registry.AllHealth()
}

// VendorHealth returns health for a specific vendor, or false when the
// vendor has no loaded plugin.
func (c *ShadowAIController) VendorHealth(vendor string) (*PluginHealth, bool) {
	return c.registry.GetHealth(vendor)
}

// The accessors below return sub-components set once at construction and
// never reassigned, so they are read without taking c.mu.

// Registry returns the plugin registry for direct access.
func (c *ShadowAIController) Registry() *PluginRegistry {
	return c.registry
}

// NetworkDetector returns the network detector for configuration.
func (c *ShadowAIController) NetworkDetector() *NetworkDetector {
	return c.netDetector
}

// BehavioralDetector returns the behavioral detector.
func (c *ShadowAIController) BehavioralDetector() *BehavioralDetector {
	return c.behavioral
}

// DocBridge returns the document review bridge.
func (c *ShadowAIController) DocBridge() *DocBridge {
	return c.docBridge
}

// ApprovalEngine returns the approval workflow engine.
func (c *ShadowAIController) ApprovalEngine() *ApprovalEngine {
	return c.approval
}
|
||||
|
||||
// ReviewDocument scans a document and creates an approval request if needed.
// The scan result is always returned; an ApprovalRequest is submitted for
// every non-blocked result (the approval engine receives the data
// classification to decide on), and a MEDIUM SOC event is emitted either way.
func (c *ShadowAIController) ReviewDocument(docID, content, userID string) (*ScanResult, *ApprovalRequest) {
	result := c.docBridge.ScanDocument(docID, content, userID)

	// Create approval request based on data classification.
	// Blocked documents never enter the approval workflow.
	var req *ApprovalRequest
	if result.Status != DocReviewBlocked {
		req = c.approval.SubmitRequest(userID, docID, result.DataClass)
	}

	// Emit SOC event for tracking.
	c.emitSOCEvent("MEDIUM", "shadow_ai_usage",
		fmt.Sprintf("Document review: %s by %s — %s (%s)",
			docID, userID, result.Status, result.DataClass),
		map[string]string{
			"user":       userID,
			"doc_id":     docID,
			"status":     string(result.Status),
			"data_class": string(result.DataClass),
			"pii_count":  fmt.Sprintf("%d", len(result.PIIFound)),
		},
	)

	return result, req
}
|
||||
|
||||
// GenerateComplianceReport generates a compliance report for the given period.
// Interaction counts come from GetStats over the same period string; PII and
// secret counts are derived from the doc bridge counters ("redacted" and
// "blocked"). NOTE(review): AuditComplete is hard-coded true — confirm it
// should not reflect actual audit state.
func (c *ShadowAIController) GenerateComplianceReport(period string) ComplianceReport {
	stats := c.GetStats(period)
	docStats := c.docBridge.Stats()
	return ComplianceReport{
		GeneratedAt:       time.Now(),
		Period:            period,
		TotalInteractions: stats.Total,
		BlockedAttempts:   stats.Blocked,
		ApprovedReviews:   stats.Approved,
		PIIDetected:       docStats["redacted"] + docStats["blocked"],
		SecretsDetected:   docStats["blocked"],
		AuditComplete:     true,
		Regulations:       []string{"GDPR", "SOC2", "EU AI Act Article 15"},
	}
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (c *ShadowAIController) recordEvent(event ShadowAIEvent) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.events = append(c.events, event)
|
||||
|
||||
// Evict oldest events if over capacity.
|
||||
if len(c.events) > c.maxEvents {
|
||||
excess := len(c.events) - c.maxEvents
|
||||
c.events = c.events[excess:]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ShadowAIController) emitSOCEvent(severity, category, description string, meta map[string]string) {
|
||||
c.mu.RLock()
|
||||
fn := c.socEventFn
|
||||
c.mu.RUnlock()
|
||||
|
||||
if fn != nil {
|
||||
fn("shadow-ai", severity, category, description, meta)
|
||||
}
|
||||
}
|
||||
|
||||
func parseCutoff(timeRange string) time.Time {
|
||||
switch timeRange {
|
||||
case "1h":
|
||||
return time.Now().Add(-1 * time.Hour)
|
||||
case "24h":
|
||||
return time.Now().Add(-24 * time.Hour)
|
||||
case "7d":
|
||||
return time.Now().Add(-7 * 24 * time.Hour)
|
||||
case "30d":
|
||||
return time.Now().Add(-30 * 24 * time.Hour)
|
||||
case "90d":
|
||||
return time.Now().Add(-90 * 24 * time.Hour)
|
||||
default:
|
||||
return time.Now().Add(-24 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
// eventCounter provides a process-wide monotonically increasing suffix so
// IDs generated within the same millisecond remain unique. atomic.Uint64
// replaces the previous mutex-guarded counter.
var eventCounter atomic.Uint64

// genEventID returns a unique event ID of the form "sai-<unix-ms>-<seq>".
func genEventID() string {
	return fmt.Sprintf("sai-%d-%d", time.Now().UnixMilli(), eventCounter.Add(1))
}
|
||||
159
internal/application/sidecar/client.go
Normal file
159
internal/application/sidecar/client.go
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
package sidecar
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"time"

	domsoc "github.com/syntrex/gomcp/internal/domain/soc"
)
|
||||
|
||||
// BusClient sends security events to the SOC Event Bus via HTTP POST.
|
||||
type BusClient struct {
|
||||
baseURL string
|
||||
sensorID string
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
maxRetries int
|
||||
}
|
||||
|
||||
// NewBusClient creates a client for the SOC Event Bus.
|
||||
func NewBusClient(baseURL, sensorID, apiKey string) *BusClient {
|
||||
return &BusClient{
|
||||
baseURL: baseURL,
|
||||
sensorID: sensorID,
|
||||
apiKey: apiKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConnsPerHost: 5,
|
||||
},
|
||||
},
|
||||
maxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// ingestPayload matches the SOC ingest API expected JSON.
// Field names mirror the /api/v1/soc/events request schema; optional
// fields carry omitempty so empty values are dropped from the wire format.
type ingestPayload struct {
	Source      string            `json:"source"`
	SensorID    string            `json:"sensor_id"`
	SensorKey   string            `json:"sensor_key,omitempty"`
	Severity    string            `json:"severity"`
	Category    string            `json:"category"`
	Subcategory string            `json:"subcategory,omitempty"`
	Confidence  float64           `json:"confidence"`
	Description string            `json:"description"`
	SessionID   string            `json:"session_id,omitempty"`
	Metadata    map[string]string `json:"metadata,omitempty"`
}
|
||||
|
||||
// SendEvent posts a SOCEvent to the Event Bus.
|
||||
// Accepts context for graceful cancellation during retries (L-2 fix).
|
||||
func (c *BusClient) SendEvent(ctx context.Context, evt *domsoc.SOCEvent) error {
|
||||
payload := ingestPayload{
|
||||
Source: string(evt.Source),
|
||||
SensorID: c.sensorID,
|
||||
SensorKey: c.apiKey,
|
||||
Severity: string(evt.Severity),
|
||||
Category: evt.Category,
|
||||
Subcategory: evt.Subcategory,
|
||||
Confidence: evt.Confidence,
|
||||
Description: evt.Description,
|
||||
SessionID: evt.SessionID,
|
||||
Metadata: evt.Metadata,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sidecar: marshal event: %w", err)
|
||||
}
|
||||
|
||||
url := c.baseURL + "/api/v1/soc/events"
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= c.maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Context-aware backoff: cancellable during shutdown (H-1 fix).
|
||||
backoff := time.Duration(attempt*attempt) * 500 * time.Millisecond
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("sidecar: send cancelled during retry: %w", ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("sidecar: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
slog.Warn("sidecar: bus POST failed, retrying",
|
||||
"attempt", attempt+1, "error", err)
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = fmt.Errorf("bus returned %d", resp.StatusCode)
|
||||
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
|
||||
// Client error — don't retry.
|
||||
return lastErr
|
||||
}
|
||||
slog.Warn("sidecar: bus returned server error, retrying",
|
||||
"attempt", attempt+1, "status", resp.StatusCode)
|
||||
}
|
||||
|
||||
return fmt.Errorf("sidecar: exhausted retries: %w", lastErr)
|
||||
}
|
||||
|
||||
// Heartbeat sends a sensor heartbeat to the Event Bus.
|
||||
func (c *BusClient) Heartbeat() error {
|
||||
payload := map[string]string{
|
||||
"sensor_id": c.sensorID,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sidecar: marshal heartbeat: %w", err)
|
||||
}
|
||||
|
||||
url := c.baseURL + "/api/soc/sensors/heartbeat"
|
||||
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("heartbeat returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Healthy checks if the bus is reachable (M-4 fix: /healthz not /health).
|
||||
func (c *BusClient) Healthy() bool {
|
||||
resp, err := c.httpClient.Get(c.baseURL + "/healthz")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
214
internal/application/sidecar/parser.go
Normal file
214
internal/application/sidecar/parser.go
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
// Package sidecar implements the Universal Sidecar (§5.5) — a zero-dependency
|
||||
// Go binary that runs alongside SENTINEL sensors, tails their STDOUT/logs,
|
||||
// and pushes parsed security events to the SOC Event Bus.
|
||||
package sidecar
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
|
||||
)
|
||||
|
||||
// Parser converts a raw log line into a SOCEvent.
// Implementations return (nil, false) when the line is not a security
// event (e.g. INFO/DEBUG noise); the sidecar then skips the line.
type Parser interface {
	Parse(line string) (*domsoc.SOCEvent, bool)
}
|
||||
|
||||
// ── sentinel-core Parser ─────────────────────────────────────────────────────
|
||||
|
||||
// SentinelCoreParser parses sentinel-core detection output.
|
||||
// Expected format: [DETECT] engine=<name> confidence=<float> pattern=<desc> [severity=<sev>]
|
||||
type SentinelCoreParser struct{}
|
||||
|
||||
var coreDetectRe = regexp.MustCompile(
|
||||
`\[DETECT\]\s+engine=(\S+)\s+confidence=([0-9.]+)\s+pattern=(.+?)(?:\s+severity=(\S+))?$`)
|
||||
|
||||
func (p *SentinelCoreParser) Parse(line string) (*domsoc.SOCEvent, bool) {
|
||||
m := coreDetectRe.FindStringSubmatch(strings.TrimSpace(line))
|
||||
if m == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
engine := m[1]
|
||||
conf, _ := strconv.ParseFloat(m[2], 64)
|
||||
pattern := m[3]
|
||||
severity := mapConfidenceToSeverity(conf)
|
||||
if m[4] != "" {
|
||||
severity = domsoc.EventSeverity(strings.ToUpper(m[4]))
|
||||
}
|
||||
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, severity, engine,
|
||||
engine+": "+pattern)
|
||||
evt.Confidence = conf
|
||||
evt.Subcategory = pattern
|
||||
return &evt, true
|
||||
}
|
||||
|
||||
// ── shield Parser ────────────────────────────────────────────────────────────
|
||||
|
||||
// ShieldParser parses shield network block logs.
|
||||
// Expected format: BLOCKED protocol=<proto> reason=<reason> source_ip=<ip>
|
||||
type ShieldParser struct{}
|
||||
|
||||
var shieldBlockRe = regexp.MustCompile(
|
||||
`BLOCKED\s+protocol=(\S+)\s+reason=(.+?)\s+source_ip=(\S+)`)
|
||||
|
||||
func (p *ShieldParser) Parse(line string) (*domsoc.SOCEvent, bool) {
|
||||
m := shieldBlockRe.FindStringSubmatch(strings.TrimSpace(line))
|
||||
if m == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
protocol := m[1]
|
||||
reason := m[2]
|
||||
sourceIP := m[3]
|
||||
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityMedium, "network_block",
|
||||
"Shield blocked "+protocol+" from "+sourceIP+": "+reason)
|
||||
evt.Subcategory = protocol
|
||||
evt.Metadata = map[string]string{
|
||||
"source_ip": sourceIP,
|
||||
"protocol": protocol,
|
||||
"reason": reason,
|
||||
}
|
||||
return &evt, true
|
||||
}
|
||||
|
||||
// ── immune Parser ────────────────────────────────────────────────────────────
|
||||
|
||||
// ImmuneParser parses immune system anomaly/response logs.
|
||||
// Expected format: [ANOMALY] type=<type> score=<float> detail=<text>
|
||||
//
|
||||
// or: [RESPONSE] action=<action> target=<target> reason=<text>
|
||||
type ImmuneParser struct{}
|
||||
|
||||
var immuneAnomalyRe = regexp.MustCompile(
|
||||
`\[ANOMALY\]\s+type=(\S+)\s+score=([0-9.]+)\s+detail=(.+)`)
|
||||
var immuneResponseRe = regexp.MustCompile(
|
||||
`\[RESPONSE\]\s+action=(\S+)\s+target=(\S+)\s+reason=(.+)`)
|
||||
|
||||
func (p *ImmuneParser) Parse(line string) (*domsoc.SOCEvent, bool) {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
if m := immuneAnomalyRe.FindStringSubmatch(trimmed); m != nil {
|
||||
anomalyType := m[1]
|
||||
score, _ := strconv.ParseFloat(m[2], 64)
|
||||
detail := m[3]
|
||||
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceImmune, mapConfidenceToSeverity(score),
|
||||
"anomaly", "Immune anomaly: "+anomalyType+": "+detail)
|
||||
evt.Confidence = score
|
||||
evt.Subcategory = anomalyType
|
||||
return &evt, true
|
||||
}
|
||||
|
||||
if m := immuneResponseRe.FindStringSubmatch(trimmed); m != nil {
|
||||
action := m[1]
|
||||
target := m[2]
|
||||
reason := m[3]
|
||||
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceImmune, domsoc.SeverityHigh,
|
||||
"immune_response", "Immune response: "+action+" on "+target+": "+reason)
|
||||
evt.Subcategory = action
|
||||
evt.Metadata = map[string]string{
|
||||
"action": action,
|
||||
"target": target,
|
||||
"reason": reason,
|
||||
}
|
||||
return &evt, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ── Generic Parser ───────────────────────────────────────────────────────────
|
||||
|
||||
// GenericParser uses a configurable regex with named groups.
|
||||
// Named groups: "category", "severity", "description", "confidence".
|
||||
type GenericParser struct {
|
||||
Pattern *regexp.Regexp
|
||||
Source domsoc.EventSource
|
||||
}
|
||||
|
||||
func NewGenericParser(pattern string, source domsoc.EventSource) (*GenericParser, error) {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &GenericParser{Pattern: re, Source: source}, nil
|
||||
}
|
||||
|
||||
func (p *GenericParser) Parse(line string) (*domsoc.SOCEvent, bool) {
|
||||
m := p.Pattern.FindStringSubmatch(strings.TrimSpace(line))
|
||||
if m == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
names := p.Pattern.SubexpNames()
|
||||
groups := map[string]string{}
|
||||
for i, name := range names {
|
||||
if i > 0 && name != "" {
|
||||
groups[name] = m[i]
|
||||
}
|
||||
}
|
||||
|
||||
category := groups["category"]
|
||||
if category == "" {
|
||||
category = "generic"
|
||||
}
|
||||
description := groups["description"]
|
||||
if description == "" {
|
||||
description = line
|
||||
}
|
||||
severity := domsoc.SeverityMedium
|
||||
if s, ok := groups["severity"]; ok && s != "" {
|
||||
severity = domsoc.EventSeverity(strings.ToUpper(s))
|
||||
}
|
||||
confidence := 0.5
|
||||
if c, ok := groups["confidence"]; ok {
|
||||
if f, err := strconv.ParseFloat(c, 64); err == nil {
|
||||
confidence = f
|
||||
}
|
||||
}
|
||||
|
||||
evt := domsoc.NewSOCEvent(p.Source, severity, category, description)
|
||||
evt.Confidence = confidence
|
||||
return &evt, true
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
// ParserForSensor returns the appropriate parser for a sensor type.
|
||||
func ParserForSensor(sensorType string) Parser {
|
||||
switch strings.ToLower(sensorType) {
|
||||
case "sentinel-core":
|
||||
return &SentinelCoreParser{}
|
||||
case "shield":
|
||||
return &ShieldParser{}
|
||||
case "immune":
|
||||
return &ImmuneParser{}
|
||||
default:
|
||||
slog.Warn("sidecar: unknown sensor type, using sentinel-core parser as fallback",
|
||||
"sensor_type", sensorType)
|
||||
return &SentinelCoreParser{} // fallback
|
||||
}
|
||||
}
|
||||
|
||||
func mapConfidenceToSeverity(conf float64) domsoc.EventSeverity {
|
||||
switch {
|
||||
case conf >= 0.9:
|
||||
return domsoc.SeverityCritical
|
||||
case conf >= 0.7:
|
||||
return domsoc.SeverityHigh
|
||||
case conf >= 0.5:
|
||||
return domsoc.SeverityMedium
|
||||
case conf >= 0.3:
|
||||
return domsoc.SeverityLow
|
||||
default:
|
||||
return domsoc.SeverityInfo
|
||||
}
|
||||
}
|
||||
157
internal/application/sidecar/sidecar.go
Normal file
157
internal/application/sidecar/sidecar.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
package sidecar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds sidecar runtime configuration. All fields are set once
// at startup and read-only afterwards.
type Config struct {
	SensorType   string        // sentinel-core, shield, immune, generic (selects the Parser)
	LogPath      string        // Path to sensor log file, or "stdin" / "-" for standard input
	BusURL       string        // SOC Event Bus URL (e.g., http://localhost:9100)
	SensorID     string        // Sensor registration ID (stamped onto every event)
	APIKey       string        // Sensor API key (passed to the bus client)
	PollInterval time.Duration // Log file poll interval; <=0 means the tailer's default
}

// Stats tracks sidecar runtime metrics (thread-safe via atomic).
// Counters are updated by the pipeline loop and read concurrently
// through GetStats.
type Stats struct {
	LinesRead  atomic.Int64 // every line received from the tailer
	EventsSent atomic.Int64 // lines that parsed into events and were delivered
	Errors     atomic.Int64 // failed deliveries to the bus
	StartedAt  time.Time    // set once in New; never mutated afterwards
}

// StatsSnapshot is a non-atomic copy for reading/logging.
type StatsSnapshot struct {
	LinesRead  int64     `json:"lines_read"`
	EventsSent int64     `json:"events_sent"`
	Errors     int64     `json:"errors"`
	StartedAt  time.Time `json:"started_at"`
}

// Sidecar is the main orchestrator: tailer → parser → bus client.
// Construct with New; run with Run or RunReader.
type Sidecar struct {
	config Config
	parser Parser
	client *BusClient
	tailer *Tailer
	stats  Stats
}
|
||||
|
||||
// New creates a Sidecar with the given config.
|
||||
func New(cfg Config) *Sidecar {
|
||||
return &Sidecar{
|
||||
config: cfg,
|
||||
parser: ParserForSensor(cfg.SensorType),
|
||||
client: NewBusClient(cfg.BusURL, cfg.SensorID, cfg.APIKey),
|
||||
tailer: NewTailer(cfg.PollInterval),
|
||||
stats: Stats{StartedAt: time.Now()},
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the sidecar pipeline: tail → parse → send.
// Blocks until ctx is cancelled.
//
// The log source is chosen by config: "stdin"/"-" follows standard input,
// anything else is tailed as a file path. A background goroutine sends
// heartbeats; it exits when ctx is cancelled.
func (s *Sidecar) Run(ctx context.Context) error {
	slog.Info("sidecar: starting",
		"sensor_type", s.config.SensorType,
		"log_path", s.config.LogPath,
		"bus_url", s.config.BusURL,
		"sensor_id", s.config.SensorID,
	)

	// Start line source.
	var lines <-chan string
	if s.config.LogPath == "stdin" || s.config.LogPath == "-" {
		lines = s.tailer.FollowStdin(ctx)
	} else {
		var err error
		lines, err = s.tailer.FollowFile(ctx, s.config.LogPath)
		if err != nil {
			// Fail fast: a missing/unreadable log file is a startup error.
			return fmt.Errorf("sidecar: open log: %w", err)
		}
	}

	// Heartbeat goroutine (stopped via ctx).
	go s.heartbeatLoop(ctx)

	// Main pipeline loop (shared with RunReader).
	return s.processLines(ctx, lines)
}
|
||||
|
||||
// RunReader runs the sidecar from any io.Reader (for testing).
|
||||
func (s *Sidecar) RunReader(ctx context.Context, r io.Reader) error {
|
||||
lines := s.tailer.FollowReader(ctx, r)
|
||||
return s.processLines(ctx, lines)
|
||||
}
|
||||
|
||||
// processLines is the shared pipeline loop: parse → send → stats.
// Extracted to DRY between Run() and RunReader() (H-3 fix).
//
// Returns nil both on ctx cancellation and on input-channel close; send
// failures are counted and logged but never abort the loop.
func (s *Sidecar) processLines(ctx context.Context, lines <-chan string) error {
	for {
		select {
		case <-ctx.Done():
			// Graceful shutdown: log final counters before returning.
			slog.Info("sidecar: shutting down",
				"lines_read", s.stats.LinesRead.Load(),
				"events_sent", s.stats.EventsSent.Load(),
				"errors", s.stats.Errors.Load(),
			)
			return nil

		case line, ok := <-lines:
			if !ok {
				// Tailer closed the channel (EOF or read error).
				slog.Info("sidecar: input closed")
				return nil
			}

			s.stats.LinesRead.Add(1)

			evt, ok := s.parser.Parse(line)
			if !ok {
				continue // Not a security event.
			}

			// Stamp our identity so the bus can attribute the event.
			evt.SensorID = s.config.SensorID
			if err := s.client.SendEvent(ctx, evt); err != nil {
				// Best-effort delivery: count, log, and keep consuming.
				s.stats.Errors.Add(1)
				slog.Error("sidecar: send failed",
					"error", err,
					"category", evt.Category,
				)
				continue
			}
			s.stats.EventsSent.Add(1)
		}
	}
}
|
||||
|
||||
// GetStats returns a snapshot of current runtime metrics (thread-safe).
|
||||
func (s *Sidecar) GetStats() StatsSnapshot {
|
||||
return StatsSnapshot{
|
||||
LinesRead: s.stats.LinesRead.Load(),
|
||||
EventsSent: s.stats.EventsSent.Load(),
|
||||
Errors: s.stats.Errors.Load(),
|
||||
StartedAt: s.stats.StartedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sidecar) heartbeatLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.client.Heartbeat(); err != nil {
|
||||
slog.Warn("sidecar: heartbeat failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
306
internal/application/sidecar/sidecar_test.go
Normal file
306
internal/application/sidecar/sidecar_test.go
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
package sidecar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ── Parser Tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
// TestSentinelCoreParser table-tests [DETECT] line parsing: category,
// confidence floor, and rejection of non-detection lines.
func TestSentinelCoreParser(t *testing.T) {
	p := &SentinelCoreParser{}

	tests := []struct {
		line     string
		wantOK   bool
		category string
		confMin  float64 // lower bound only; exact float compare avoided
	}{
		{"[DETECT] engine=jailbreak confidence=0.95 pattern=DAN prompt", true, "jailbreak", 0.9},
		{"[DETECT] engine=injection confidence=0.6 pattern=ignore_previous", true, "injection", 0.5},
		{"[DETECT] engine=exfiltration confidence=0.3 pattern=tool_call severity=HIGH", true, "exfiltration", 0.2},
		{"INFO: Engine loaded successfully", false, "", 0},
		{"", false, "", 0},
	}

	for _, tt := range tests {
		evt, ok := p.Parse(tt.line)
		if ok != tt.wantOK {
			t.Errorf("Parse(%q) ok=%v, want %v", tt.line, ok, tt.wantOK)
			continue
		}
		if !ok {
			continue // nothing more to check for rejected lines
		}
		if evt.Category != tt.category {
			t.Errorf("Parse(%q) category=%q, want %q", tt.line, evt.Category, tt.category)
		}
		if evt.Confidence < tt.confMin {
			t.Errorf("Parse(%q) confidence=%.2f, want >=%.2f", tt.line, evt.Confidence, tt.confMin)
		}
	}
}
|
||||
|
||||
// TestShieldParser table-tests BLOCKED line parsing and checks that the
// protocol and source IP land in the event metadata.
func TestShieldParser(t *testing.T) {
	p := &ShieldParser{}

	tests := []struct {
		line   string
		wantOK bool
		proto  string
		ip     string
	}{
		{"BLOCKED protocol=tcp reason=port_scan source_ip=192.168.1.100", true, "tcp", "192.168.1.100"},
		{"BLOCKED protocol=udp reason=dns_exfil source_ip=10.0.0.5", true, "udp", "10.0.0.5"},
		{"ALLOWED protocol=https from 1.2.3.4", false, "", ""},
		{"", false, "", ""},
	}

	for _, tt := range tests {
		evt, ok := p.Parse(tt.line)
		if ok != tt.wantOK {
			t.Errorf("Parse(%q) ok=%v, want %v", tt.line, ok, tt.wantOK)
			continue
		}
		if !ok {
			continue
		}
		if evt.Metadata["protocol"] != tt.proto {
			t.Errorf("protocol=%q, want %q", evt.Metadata["protocol"], tt.proto)
		}
		if evt.Metadata["source_ip"] != tt.ip {
			t.Errorf("source_ip=%q, want %q", evt.Metadata["source_ip"], tt.ip)
		}
	}
}
|
||||
|
||||
// TestImmuneParser covers both line shapes the parser accepts:
// [ANOMALY] → category "anomaly", [RESPONSE] → category "immune_response".
func TestImmuneParser(t *testing.T) {
	p := &ImmuneParser{}

	tests := []struct {
		line     string
		wantOK   bool
		category string
	}{
		{"[ANOMALY] type=drift score=0.85 detail=behavior shift detected", true, "anomaly"},
		{"[RESPONSE] action=quarantine target=session-123 reason=high risk", true, "immune_response"},
		{"[INFO] system healthy", false, ""},
	}

	for _, tt := range tests {
		evt, ok := p.Parse(tt.line)
		if ok != tt.wantOK {
			t.Errorf("Parse(%q) ok=%v, want %v", tt.line, ok, tt.wantOK)
			continue
		}
		if !ok {
			continue
		}
		if evt.Category != tt.category {
			t.Errorf("category=%q, want %q", evt.Category, tt.category)
		}
	}
}
|
||||
|
||||
// TestGenericParser checks that a user-supplied pattern with named
// groups maps into category/severity/description on the event.
func TestGenericParser(t *testing.T) {
	p, err := NewGenericParser(
		`ALERT\s+(?P<category>\S+)\s+(?P<severity>\S+)\s+(?P<description>.+)`,
		"external",
	)
	if err != nil {
		t.Fatalf("NewGenericParser: %v", err)
	}

	evt, ok := p.Parse("ALERT injection HIGH suspicious sql in query string")
	if !ok {
		t.Fatal("expected match")
	}
	if evt.Category != "injection" {
		t.Errorf("category=%q, want injection", evt.Category)
	}
	// Severity is upper-cased by the parser from the captured group.
	if string(evt.Severity) != "HIGH" {
		t.Errorf("severity=%q, want HIGH", evt.Severity)
	}
}
|
||||
|
||||
// TestParserForSensor verifies the sensor-type → parser dispatch,
// including the sentinel-core fallback for unknown types.
func TestParserForSensor(t *testing.T) {
	tests := map[string]string{
		"sentinel-core": "*sidecar.SentinelCoreParser",
		"shield":        "*sidecar.ShieldParser",
		"immune":        "*sidecar.ImmuneParser",
		"unknown":       "*sidecar.SentinelCoreParser", // fallback
	}
	for sensorType, wantType := range tests {
		p := ParserForSensor(sensorType)
		if p == nil {
			t.Errorf("ParserForSensor(%q) returned nil", sensorType)
			continue
		}
		// Compare the dynamic type name rather than asserting concrete types.
		gotType := fmt.Sprintf("%T", p)
		if gotType != wantType {
			t.Errorf("ParserForSensor(%q) = %s, want %s", sensorType, gotType, wantType)
		}
	}
}
|
||||
|
||||
// ── Tailer Tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
// TestTailer_FollowReader feeds three newline-terminated lines through
// FollowReader and verifies all are emitted in order. The channel
// closes at EOF, so ranging over it terminates the loop.
func TestTailer_FollowReader(t *testing.T) {
	input := "[DETECT] engine=jailbreak confidence=0.95 pattern=DAN\nINFO: done\n[DETECT] engine=exfil confidence=0.7 pattern=tool_call\n"
	reader := strings.NewReader(input)

	tailer := NewTailer(50 * time.Millisecond)
	// Timeout guards against a hang if the channel never closes.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	ch := tailer.FollowReader(ctx, reader)

	var lines []string
	for line := range ch {
		lines = append(lines, line)
	}

	if len(lines) != 3 {
		t.Fatalf("expected 3 lines, got %d: %v", len(lines), lines)
	}

	if lines[0] != "[DETECT] engine=jailbreak confidence=0.95 pattern=DAN" {
		t.Errorf("line[0]=%q", lines[0])
	}
}
|
||||
|
||||
// ── BusClient Tests ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestBusClient_SendEvent(t *testing.T) {
|
||||
var received []map[string]any
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/v1/soc/events" {
|
||||
var payload map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&payload)
|
||||
received = append(received, payload)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := NewBusClient(ts.URL, "test-sensor", "test-key")
|
||||
|
||||
p := &SentinelCoreParser{}
|
||||
evt, ok := p.Parse("[DETECT] engine=jailbreak confidence=0.95 pattern=DAN")
|
||||
if !ok {
|
||||
t.Fatal("parse failed")
|
||||
}
|
||||
|
||||
err := client.SendEvent(context.Background(), evt)
|
||||
if err != nil {
|
||||
t.Fatalf("SendEvent: %v", err)
|
||||
}
|
||||
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 received event, got %d", len(received))
|
||||
}
|
||||
|
||||
if received[0]["source"] != "sentinel-core" {
|
||||
t.Errorf("source=%v, want sentinel-core", received[0]["source"])
|
||||
}
|
||||
if received[0]["category"] != "jailbreak" {
|
||||
t.Errorf("category=%v, want jailbreak", received[0]["category"])
|
||||
}
|
||||
if received[0]["sensor_id"] != "test-sensor" {
|
||||
t.Errorf("sensor_id=%v, want test-sensor", received[0]["sensor_id"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBusClient_Healthy checks both outcomes: a 200-answering server is
// healthy; an unreachable address is not.
func TestBusClient_Healthy(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	client := NewBusClient(ts.URL, "s1", "k1")
	if !client.Healthy() {
		t.Error("expected healthy")
	}

	// Unreachable server (port 1 is essentially never listening).
	client2 := NewBusClient("http://localhost:1", "s2", "k2")
	if client2.Healthy() {
		t.Error("expected unhealthy")
	}
}
|
||||
|
||||
// ── E2E Pipeline Test ────────────────────────────────────────────────────────
|
||||
|
||||
// TestSidecar_E2E_Pipeline drives the full tail→parse→send pipeline from
// an in-memory reader through RunReader into a stub bus: 5 input lines,
// of which 3 [DETECT] lines become events and 2 noise lines are skipped.
func TestSidecar_E2E_Pipeline(t *testing.T) {
	var receivedEvents []map[string]any

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v1/soc/events":
			var payload map[string]any
			json.NewDecoder(r.Body).Decode(&payload)
			receivedEvents = append(receivedEvents, payload)
			w.WriteHeader(http.StatusCreated)
		case "/health":
			// NOTE(review): BusClient.Healthy probes /healthz, not /health;
			// harmless here (the default branch also returns 200, and
			// RunReader starts no heartbeat) but worth aligning.
			w.WriteHeader(http.StatusOK)
		default:
			w.WriteHeader(http.StatusOK)
		}
	}))
	defer ts.Close()

	input := strings.Join([]string{
		"[DETECT] engine=jailbreak confidence=0.95 pattern=DAN",
		"INFO: processing complete",
		"[DETECT] engine=injection confidence=0.7 pattern=ignore_previous",
		"DEBUG: internal state update",
		"[DETECT] engine=exfiltration confidence=0.5 pattern=tool_call",
	}, "\n")

	cfg := Config{
		SensorType: "sentinel-core",
		LogPath:    "stdin",
		BusURL:     ts.URL,
		SensorID:   "e2e-test-sensor",
		APIKey:     "test-key",
	}

	sc := New(cfg)
	// Timeout bounds the test; RunReader normally returns at reader EOF.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	err := sc.RunReader(ctx, strings.NewReader(input))
	if err != nil {
		t.Fatalf("RunReader: %v", err)
	}

	stats := sc.GetStats()
	if stats.LinesRead != 5 {
		t.Errorf("LinesRead=%d, want 5", stats.LinesRead)
	}
	if stats.EventsSent != 3 {
		t.Errorf("EventsSent=%d, want 3 (3 DETECT lines, 2 skipped)", stats.EventsSent)
	}

	if len(receivedEvents) != 3 {
		t.Fatalf("received %d events, want 3", len(receivedEvents))
	}

	// Verify first event carries the parsed category and stamped sensor ID.
	first := receivedEvents[0]
	if first["category"] != "jailbreak" {
		t.Errorf("first event category=%v, want jailbreak", first["category"])
	}
	if first["sensor_id"] != "e2e-test-sensor" {
		t.Errorf("first event sensor_id=%v, want e2e-test-sensor", first["sensor_id"])
	}
}
|
||||
162
internal/application/sidecar/tailer.go
Normal file
162
internal/application/sidecar/tailer.go
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
package sidecar
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Tailer follows a log file or stdin, emitting lines via a channel.
type Tailer struct {
	pollInterval time.Duration // wait between EOF re-checks when tailing a file
}

// NewTailer creates a Tailer with the given poll interval for file
// changes. Non-positive intervals fall back to a 200ms default.
func NewTailer(pollInterval time.Duration) *Tailer {
	interval := pollInterval
	if interval <= 0 {
		interval = 200 * time.Millisecond
	}
	return &Tailer{pollInterval: interval}
}
|
||||
|
||||
// FollowFile tails a file, seeking to end on start.
// Sends lines on the returned channel until ctx is cancelled.
//
// Only lines written after the call are emitted (initial seek to end).
// On EOF the goroutine polls every pollInterval and reopens the file if
// the path now points at a different inode (log rotation). The channel
// is closed on ctx cancellation, a read error, or a failed reopen.
// NOTE(review): a final line without a trailing newline may be emitted
// once at EOF and is not re-read — confirm sensors always newline-terminate.
func (t *Tailer) FollowFile(ctx context.Context, path string) (<-chan string, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}

	// Seek to end — only process new lines.
	if _, err := f.Seek(0, io.SeekEnd); err != nil {
		f.Close()
		return nil, err
	}

	// Buffered so slow consumers don't immediately stall the reader.
	ch := make(chan string, 256)

	go func() {
		defer f.Close()
		defer close(ch)

		// H-2 fix: Use Scanner with 1MB max line size to prevent OOM.
		const maxLineSize = 1 << 20 // 1MB
		scanner := bufio.NewScanner(f)
		scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

		for {
			// Non-blocking cancellation check before each read attempt.
			select {
			case <-ctx.Done():
				return
			default:
			}

			if scanner.Scan() {
				line := scanner.Text()
				if line != "" {
					// Blocking send, but abort if cancelled mid-send.
					select {
					case ch <- line:
					case <-ctx.Done():
						return
					}
				}
				continue
			}

			// Scanner stopped — either EOF or error.
			if err := scanner.Err(); err != nil {
				slog.Error("sidecar: read error", "error", err)
				return
			}

			// EOF — wait and check for rotation.
			time.Sleep(t.pollInterval)

			if t.fileRotated(f, path) {
				slog.Info("sidecar: log rotated, reopening", "path", path)
				f.Close()
				newF, err := os.Open(path)
				if err != nil {
					slog.Error("sidecar: reopen failed", "path", path, "error", err)
					return
				}
				// The deferred f.Close() above closes this handle on exit,
				// since f is rebound to the reopened file.
				f = newF
				scanner = bufio.NewScanner(f)
				scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
			} else {
				// Same file, re-create scanner at current position.
				// A Scanner that returned false stays stopped, so a fresh
				// one is required to pick up newly appended data.
				scanner = bufio.NewScanner(f)
				scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
			}
		}
	}()

	return ch, nil
}
|
||||
|
||||
// FollowStdin reads from stdin line by line.
|
||||
func (t *Tailer) FollowStdin(ctx context.Context) <-chan string {
|
||||
ch := make(chan string, 256)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
line := scanner.Text()
|
||||
if line != "" {
|
||||
select {
|
||||
case ch <- line:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// FollowReader reads from any io.Reader (for testing).
|
||||
func (t *Tailer) FollowReader(ctx context.Context, r io.Reader) <-chan string {
|
||||
ch := make(chan string, 256)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
line := scanner.Text()
|
||||
if line != "" {
|
||||
select {
|
||||
case ch <- line:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// fileRotated checks if the file path now points to a different inode.
|
||||
func (t *Tailer) fileRotated(current *os.File, path string) bool {
|
||||
curInfo, err1 := current.Stat()
|
||||
newInfo, err2 := os.Stat(path)
|
||||
if err1 != nil || err2 != nil {
|
||||
return false
|
||||
}
|
||||
return !os.SameFile(curInfo, newInfo)
|
||||
}
|
||||
527
internal/application/soc/e2e_test.go
Normal file
527
internal/application/soc/e2e_test.go
Normal file
|
|
@ -0,0 +1,527 @@
|
|||
package soc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/audit"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/sqlite"
|
||||
)
|
||||
|
||||
// newTestServiceWithLogger creates a SOC service backed by in-memory SQLite WITH a decision logger.
// Cleanup (logger then db) is registered via t.Cleanup so callers need
// no teardown of their own.
func newTestServiceWithLogger(t *testing.T) *Service {
	t.Helper()
	db, err := sqlite.OpenMemory()
	require.NoError(t, err)

	repo, err := sqlite.NewSOCRepo(db)
	require.NoError(t, err)

	logger, err := audit.NewDecisionLogger(t.TempDir())
	require.NoError(t, err)

	// Close logger BEFORE TempDir cleanup (Windows file locking).
	t.Cleanup(func() {
		logger.Close()
		db.Close()
	})

	return NewService(repo, logger)
}
|
||||
|
||||
// --- E2E: Full Pipeline (Ingest → Correlation → Incident → Playbook) ---
|
||||
|
||||
// TestE2E_FullPipeline_IngestToIncident exercises ingest → correlation →
// incident → audit chain: a jailbreak event followed by a tool_abuse
// event from the same sensor should trigger correlation rule SOC-CR-001
// and create a critical incident.
func TestE2E_FullPipeline_IngestToIncident(t *testing.T) {
	svc := newTestServiceWithLogger(t)

	// Step 1: Ingest a jailbreak event.
	evt1 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "detected jailbreak attempt")
	evt1.SensorID = "sensor-e2e-1"
	id1, inc1, err := svc.IngestEvent(evt1)
	require.NoError(t, err)
	assert.NotEmpty(t, id1)
	assert.Nil(t, inc1, "single event should not trigger correlation")

	// Step 2: Ingest a tool_abuse event from same source — triggers SOC-CR-001.
	evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "tool abuse detected")
	evt2.SensorID = "sensor-e2e-1"
	id2, inc2, err := svc.IngestEvent(evt2)
	require.NoError(t, err)
	assert.NotEmpty(t, id2)

	// Correlation rule SOC-CR-001 (jailbreak + tool_abuse) should trigger an incident.
	require.NotNil(t, inc2, "jailbreak + tool_abuse should create an incident")
	assert.Equal(t, domsoc.SeverityCritical, inc2.Severity)
	assert.Equal(t, "Multi-stage Jailbreak", inc2.Title)
	assert.NotEmpty(t, inc2.ID)
	assert.NotEmpty(t, inc2.Events, "incident should reference triggering events")

	// Step 3: Verify incident is persisted.
	gotInc, err := svc.GetIncident(inc2.ID)
	require.NoError(t, err)
	assert.Equal(t, inc2.ID, gotInc.ID)

	// Step 4: Verify decision chain integrity via the dashboard rollup.
	dash, err := svc.Dashboard()
	require.NoError(t, err)
	assert.True(t, dash.ChainValid, "decision chain should be valid")
	assert.Greater(t, dash.TotalEvents, 0)
}
|
||||
|
||||
// TestE2E_TemporalSequenceCorrelation feeds an ordered auth_bypass →
// tool_abuse pair through ingest. NOTE(review): the incident check is
// conditional (`if inc != nil`), so this test passes even when no rule
// fires — consider asserting the expected rule explicitly.
func TestE2E_TemporalSequenceCorrelation(t *testing.T) {
	svc := newTestServiceWithLogger(t)

	// Sequence rule SOC-CR-010: auth_bypass → tool_abuse (ordered).
	evt1 := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "auth_bypass", "brute force detected")
	evt1.SensorID = "sensor-seq-1"
	_, _, err := svc.IngestEvent(evt1)
	require.NoError(t, err)

	evt2 := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "tool_abuse", "tool escalation")
	evt2.SensorID = "sensor-seq-1"
	_, inc, err := svc.IngestEvent(evt2)
	require.NoError(t, err)

	// Should trigger either SOC-CR-010 (sequence) or another matching rule.
	if inc != nil {
		assert.NotEmpty(t, inc.KillChainPhase)
		assert.NotEmpty(t, inc.MITREMapping)
	}
}
|
||||
|
||||
// --- E2E: Sensor Authentication Flow ---
|
||||
|
||||
// TestE2E_SensorAuth_FullFlow covers the four sensor-auth outcomes:
// valid key accepted; wrong key, missing sensor ID (S-1 fix), and
// unknown sensor all rejected with recognizable error text.
func TestE2E_SensorAuth_FullFlow(t *testing.T) {
	svc := newTestServiceWithLogger(t)

	// Configure sensor keys.
	svc.SetSensorKeys(map[string]string{
		"sensor-auth-1": "secret-key-1",
		"sensor-auth-2": "secret-key-2",
	})

	// Valid auth — should succeed.
	evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "auth test")
	evt.SensorID = "sensor-auth-1"
	evt.SensorKey = "secret-key-1"
	id, _, err := svc.IngestEvent(evt)
	require.NoError(t, err)
	assert.NotEmpty(t, id)

	// Invalid key — should fail.
	evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "bad key")
	evt2.SensorID = "sensor-auth-1"
	evt2.SensorKey = "wrong-key"
	_, _, err = svc.IngestEvent(evt2)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "auth")

	// Missing SensorID — should fail (S-1 fix).
	evt3 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "no sensor id")
	_, _, err = svc.IngestEvent(evt3)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "sensor_id required")

	// Unknown sensor — should fail.
	evt4 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "unknown sensor")
	evt4.SensorID = "sensor-unknown"
	evt4.SensorKey = "whatever"
	_, _, err = svc.IngestEvent(evt4)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "auth")
}
|
||||
|
||||
// --- E2E: Drain Mode ---
|
||||
|
||||
func TestE2E_DrainMode_RejectsNewEvents(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// Ingest works before drain.
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "pre-drain")
|
||||
evt.SensorID = "sensor-drain"
|
||||
_, _, err := svc.IngestEvent(evt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Activate drain mode.
|
||||
svc.Drain()
|
||||
assert.True(t, svc.IsDraining())
|
||||
|
||||
// New events should be rejected.
|
||||
evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "during-drain")
|
||||
evt2.SensorID = "sensor-drain"
|
||||
_, _, err = svc.IngestEvent(evt2)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "draining")
|
||||
|
||||
// Resume.
|
||||
svc.Resume()
|
||||
assert.False(t, svc.IsDraining())
|
||||
|
||||
// Events should work again.
|
||||
evt3 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "post-drain")
|
||||
evt3.SensorID = "sensor-drain"
|
||||
_, _, err = svc.IngestEvent(evt3)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// --- E2E: Webhook Delivery ---
|
||||
|
||||
func TestE2E_WebhookFiredOnIncident(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// Set up a test webhook server.
|
||||
var mu sync.Mutex
|
||||
var received []string
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
received = append(received, r.URL.Path)
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
svc.SetWebhookConfig(WebhookConfig{
|
||||
Endpoints: []string{ts.URL + "/webhook"},
|
||||
MaxRetries: 1,
|
||||
TimeoutSec: 5,
|
||||
})
|
||||
|
||||
// Trigger an incident via correlation.
|
||||
evt1 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "jailbreak e2e")
|
||||
evt1.SensorID = "sensor-wh"
|
||||
svc.IngestEvent(evt1)
|
||||
|
||||
evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "tool abuse e2e")
|
||||
evt2.SensorID = "sensor-wh"
|
||||
_, inc, err := svc.IngestEvent(evt2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if inc != nil {
|
||||
// Give the async webhook goroutine time to fire.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(received), 1, "webhook should have been called")
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// --- E2E: Verdict Flow ---
|
||||
|
||||
func TestE2E_VerdictFlow(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// Create an incident via correlation.
|
||||
evt1 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "verdict test 1")
|
||||
evt1.SensorID = "sensor-vd"
|
||||
svc.IngestEvent(evt1)
|
||||
|
||||
evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "verdict test 2")
|
||||
evt2.SensorID = "sensor-vd"
|
||||
_, inc, _ := svc.IngestEvent(evt2)
|
||||
|
||||
if inc == nil {
|
||||
t.Skip("no incident created — correlation rules may not match with current sliding window state")
|
||||
}
|
||||
|
||||
// Verify initial status is OPEN.
|
||||
got, err := svc.GetIncident(inc.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, domsoc.StatusOpen, got.Status)
|
||||
|
||||
// Update to INVESTIGATING.
|
||||
err = svc.UpdateVerdict(inc.ID, domsoc.StatusInvestigating)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, _ = svc.GetIncident(inc.ID)
|
||||
assert.Equal(t, domsoc.StatusInvestigating, got.Status)
|
||||
|
||||
// Update to RESOLVED.
|
||||
err = svc.UpdateVerdict(inc.ID, domsoc.StatusResolved)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, _ = svc.GetIncident(inc.ID)
|
||||
assert.Equal(t, domsoc.StatusResolved, got.Status)
|
||||
}
|
||||
|
||||
// --- E2E: Analytics Report ---
|
||||
|
||||
func TestE2E_AnalyticsReport(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// Ingest several events.
|
||||
categories := []string{"jailbreak", "injection", "exfiltration", "auth_bypass", "tool_abuse"}
|
||||
for i, cat := range categories {
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, cat, fmt.Sprintf("analytics test %d", i))
|
||||
evt.SensorID = "sensor-analytics"
|
||||
svc.IngestEvent(evt)
|
||||
}
|
||||
|
||||
report, err := svc.Analytics(24)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, report)
|
||||
assert.Greater(t, len(report.TopCategories), 0)
|
||||
assert.Greater(t, len(report.TopSources), 0)
|
||||
assert.GreaterOrEqual(t, report.EventsPerHour, float64(0))
|
||||
}
|
||||
|
||||
// --- E2E: Multi-Sensor Concurrent Ingest ---
|
||||
|
||||
func TestE2E_ConcurrentIngest(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make([]error, 0)
|
||||
var mu sync.Mutex
|
||||
|
||||
// 10 sensors × 10 events each = 100 concurrent ingests.
|
||||
for s := 0; s < 10; s++ {
|
||||
wg.Add(1)
|
||||
go func(sensorNum int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
evt := domsoc.NewSOCEvent(
|
||||
domsoc.SourceSentinelCore,
|
||||
domsoc.SeverityLow,
|
||||
"test",
|
||||
fmt.Sprintf("concurrent sensor-%d event-%d", sensorNum, i),
|
||||
)
|
||||
evt.SensorID = fmt.Sprintf("sensor-conc-%d", sensorNum)
|
||||
_, _, err := svc.IngestEvent(evt)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
errors = append(errors, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
}(s)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Some events may be rate-limited (100 events/sec per sensor),
|
||||
// but there should be no panics or data corruption.
|
||||
dash, err := svc.Dashboard()
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, dash.TotalEvents, 0, "at least some events should have been ingested")
|
||||
}
|
||||
|
||||
// --- E2E: Lattice TSA Chain Violation (SOC-CR-012) ---
|
||||
|
||||
func TestE2E_TSAChainViolation(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// SOC-CR-012 requires: auth_bypass → tool_abuse → exfiltration within 15 min.
|
||||
events := []struct {
|
||||
category string
|
||||
severity domsoc.EventSeverity
|
||||
}{
|
||||
{"auth_bypass", domsoc.SeverityHigh},
|
||||
{"tool_abuse", domsoc.SeverityHigh},
|
||||
{"exfiltration", domsoc.SeverityCritical},
|
||||
}
|
||||
|
||||
var lastInc *domsoc.Incident
|
||||
for _, e := range events {
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, e.severity, e.category, "TSA chain test: "+e.category)
|
||||
evt.SensorID = "sensor-tsa"
|
||||
_, inc, err := svc.IngestEvent(evt)
|
||||
require.NoError(t, err)
|
||||
if inc != nil {
|
||||
lastInc = inc
|
||||
}
|
||||
}
|
||||
|
||||
// The TSA chain (auth_bypass + tool_abuse + exfiltration) should trigger
|
||||
// SOC-CR-012 or another matching rule.
|
||||
require.NotNil(t, lastInc, "TSA chain (auth_bypass → tool_abuse → exfiltration) should create an incident")
|
||||
assert.Equal(t, domsoc.SeverityCritical, lastInc.Severity)
|
||||
assert.NotEmpty(t, lastInc.MITREMapping)
|
||||
|
||||
// Verify incident is persisted.
|
||||
got, err := svc.GetIncident(lastInc.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, lastInc.ID, got.ID)
|
||||
}
|
||||
|
||||
// --- E2E: Zero-G Mode Excludes Playbook Auto-Response ---
|
||||
|
||||
// TestE2E_ZeroGExcludedFromAutoResponse verifies that Zero-G events still go
// through correlation (so an incident may be created) but are excluded from
// playbook auto-response: no webhook fires, and the decision log chain
// remains intact with the skip recorded.
func TestE2E_ZeroGExcludedFromAutoResponse(t *testing.T) {
	svc := newTestServiceWithLogger(t)

	// Set up a test webhook server to track playbook webhook notifications.
	// Counter is mutex-guarded because deliveries arrive asynchronously.
	var mu sync.Mutex
	var webhookCalls int
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		webhookCalls++
		mu.Unlock()
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	svc.SetWebhookConfig(WebhookConfig{
		Endpoints:  []string{ts.URL + "/webhook"},
		MaxRetries: 1,
		TimeoutSec: 5,
	})

	// Ingest jailbreak + tool_abuse with ZeroGMode=true.
	// This should trigger correlation (incident created) but NOT playbooks.
	evt1 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "zero-g jailbreak test")
	evt1.SensorID = "sensor-zg"
	evt1.ZeroGMode = true
	_, _, err := svc.IngestEvent(evt1)
	require.NoError(t, err)

	evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "zero-g tool abuse test")
	evt2.SensorID = "sensor-zg"
	evt2.ZeroGMode = true
	_, inc, err := svc.IngestEvent(evt2)
	require.NoError(t, err)

	// Correlation should still run — incident should be created.
	if inc != nil {
		assert.Equal(t, domsoc.SeverityCritical, inc.Severity)

		// Wait for any async webhook goroutines.
		// NOTE(review): fixed sleep is timing-sensitive and could be flaky
		// under heavy CI load — a polling wait would be more robust.
		time.Sleep(200 * time.Millisecond)

		// Webhook should NOT have been called (playbook skipped for Zero-G).
		mu.Lock()
		assert.Equal(t, 0, webhookCalls, "webhooks should NOT fire for Zero-G events — playbook must be skipped")
		mu.Unlock()
	}

	// Verify decision log records the PLAYBOOK_SKIPPED:ZERO_G entry.
	logPath := svc.DecisionLogPath()
	if logPath != "" {
		valid, broken, err := audit.VerifyChainFromFile(logPath)
		require.NoError(t, err)
		assert.Equal(t, 0, broken, "decision chain should be intact")
		assert.Greater(t, valid, 0, "should have decision entries")
	}
}
|
||||
|
||||
// --- E2E: Decision Logger Tamper Detection ---
|
||||
|
||||
func TestE2E_DecisionLoggerTampering(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// Ingest several events to build up a decision chain.
|
||||
for i := 0; i < 10; i++ {
|
||||
evt := domsoc.NewSOCEvent(
|
||||
domsoc.SourceSentinelCore,
|
||||
domsoc.SeverityLow,
|
||||
"test",
|
||||
fmt.Sprintf("tamper test event %d", i),
|
||||
)
|
||||
evt.SensorID = "sensor-tamper"
|
||||
_, _, err := svc.IngestEvent(evt)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Step 1: Verify chain is valid.
|
||||
logPath := svc.DecisionLogPath()
|
||||
require.NotEmpty(t, logPath, "decision log path should be set")
|
||||
|
||||
validCount, brokenLine, err := audit.VerifyChainFromFile(logPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, brokenLine, "chain should be intact before tampering")
|
||||
assert.GreaterOrEqual(t, validCount, 10, "should have at least 10 decision entries")
|
||||
|
||||
// Step 2: Tamper with the log file — modify a line mid-chain.
|
||||
data, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
if len(lines) > 5 {
|
||||
// Corrupt line 5 by altering content.
|
||||
lines[4] = []byte("TAMPERED|2026-01-01T00:00:00Z|SOC|FAKE|fake_reason|0000000000")
|
||||
|
||||
err = os.WriteFile(logPath, bytes.Join(lines, []byte("\n")), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 3: Verify chain detects the tamper.
|
||||
_, brokenLine2, err2 := audit.VerifyChainFromFile(logPath)
|
||||
require.NoError(t, err2)
|
||||
assert.Greater(t, brokenLine2, 0, "chain should detect tampering — broken line reported")
|
||||
}
|
||||
}
|
||||
|
||||
// --- E2E: Cross-Sensor Session Correlation (SOC-CR-011) ---
|
||||
|
||||
func TestE2E_CrossSensorSessionCorrelation(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// SOC-CR-011 requires 3+ events from different sensors with same session_id.
|
||||
sessionID := "session-xsensor-e2e-001"
|
||||
|
||||
sources := []struct {
|
||||
source domsoc.EventSource
|
||||
sensor string
|
||||
category string
|
||||
}{
|
||||
{domsoc.SourceShield, "sensor-shield-1", "auth_bypass"},
|
||||
{domsoc.SourceSentinelCore, "sensor-core-1", "jailbreak"},
|
||||
{domsoc.SourceImmune, "sensor-immune-1", "exfiltration"},
|
||||
}
|
||||
|
||||
var lastInc *domsoc.Incident
|
||||
for _, s := range sources {
|
||||
evt := domsoc.NewSOCEvent(s.source, domsoc.SeverityHigh, s.category, "cross-sensor test: "+s.category)
|
||||
evt.SensorID = s.sensor
|
||||
evt.SessionID = sessionID
|
||||
_, inc, err := svc.IngestEvent(evt)
|
||||
require.NoError(t, err)
|
||||
if inc != nil {
|
||||
lastInc = inc
|
||||
}
|
||||
}
|
||||
|
||||
// After 3 events from different sensors/sources with same session_id,
|
||||
// at least one correlation rule should have matched.
|
||||
require.NotNil(t, lastInc, "cross-sensor session attack (3 sources, same session_id) should create incident")
|
||||
assert.NotEmpty(t, lastInc.ID)
|
||||
assert.NotEmpty(t, lastInc.Events, "incident should reference triggering events")
|
||||
}
|
||||
|
||||
// --- E2E: Crescendo Escalation (SOC-CR-015) ---
|
||||
|
||||
func TestE2E_CrescendoEscalation(t *testing.T) {
|
||||
svc := newTestServiceWithLogger(t)
|
||||
|
||||
// SOC-CR-015: 3+ jailbreak events with ascending severity within 15 min.
|
||||
severities := []domsoc.EventSeverity{
|
||||
domsoc.SeverityLow,
|
||||
domsoc.SeverityMedium,
|
||||
domsoc.SeverityHigh,
|
||||
}
|
||||
|
||||
var lastInc *domsoc.Incident
|
||||
for i, sev := range severities {
|
||||
evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, sev, "jailbreak",
|
||||
fmt.Sprintf("crescendo jailbreak attempt %d", i+1))
|
||||
evt.SensorID = "sensor-crescendo"
|
||||
_, inc, err := svc.IngestEvent(evt)
|
||||
require.NoError(t, err)
|
||||
if inc != nil {
|
||||
lastInc = inc
|
||||
}
|
||||
}
|
||||
|
||||
// The ascending severity pattern (LOW→MEDIUM→HIGH) should trigger SOC-CR-015.
|
||||
require.NotNil(t, lastInc, "crescendo pattern (LOW→MEDIUM→HIGH jailbreaks) should create incident")
|
||||
assert.Equal(t, domsoc.SeverityCritical, lastInc.Severity)
|
||||
assert.Contains(t, lastInc.MITREMapping, "T1059")
|
||||
}
|
||||
|
||||
100
internal/application/soc/ingest_bench_test.go
Normal file
100
internal/application/soc/ingest_bench_test.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package soc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/audit"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/sqlite"
|
||||
)
|
||||
|
||||
// newBenchService creates a minimal SOC service for benchmarking.
|
||||
// Disables rate limiting to measure raw pipeline throughput.
|
||||
func newBenchService(b *testing.B) *Service {
|
||||
b.Helper()
|
||||
|
||||
tmpDir := b.TempDir()
|
||||
dbPath := tmpDir + "/bench.db"
|
||||
|
||||
db, err := sqlite.Open(dbPath)
|
||||
require.NoError(b, err)
|
||||
b.Cleanup(func() { db.Close() })
|
||||
|
||||
repo, err := sqlite.NewSOCRepo(db)
|
||||
require.NoError(b, err)
|
||||
|
||||
logger, err := audit.NewDecisionLogger(tmpDir)
|
||||
require.NoError(b, err)
|
||||
b.Cleanup(func() { logger.Close() })
|
||||
|
||||
svc := NewService(repo, logger)
|
||||
svc.DisableRateLimit() // benchmarks measure throughput, not rate limiting
|
||||
return svc
|
||||
}
|
||||
|
||||
// BenchmarkIngestEvent measures single-event pipeline throughput.
|
||||
func BenchmarkIngestEvent(b *testing.B) {
|
||||
svc := newBenchService(b)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityMedium, "injection",
|
||||
fmt.Sprintf("Bench event #%d", i))
|
||||
event.ID = fmt.Sprintf("bench-evt-%d", i)
|
||||
_, _, err := svc.IngestEvent(event)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkIngestEvent_WithCorrelation measures pipeline with correlation active.
|
||||
// Pre-loads events to trigger correlation matching.
|
||||
func BenchmarkIngestEvent_WithCorrelation(b *testing.B) {
|
||||
svc := newBenchService(b)
|
||||
|
||||
// Pre-load events to make correlation rules meaningful.
|
||||
for i := 0; i < 50; i++ {
|
||||
event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "jailbreak",
|
||||
fmt.Sprintf("Pre-load jailbreak #%d", i))
|
||||
event.ID = fmt.Sprintf("preload-%d", i)
|
||||
svc.IngestEvent(event)
|
||||
time.Sleep(time.Microsecond)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "jailbreak",
|
||||
fmt.Sprintf("Corr bench event #%d", i))
|
||||
event.ID = fmt.Sprintf("bench-corr-%d", i)
|
||||
_, _, err := svc.IngestEvent(event)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkIngestEvent_Parallel measures concurrent ingest throughput.
|
||||
func BenchmarkIngestEvent_Parallel(b *testing.B) {
|
||||
svc := newBenchService(b)
|
||||
var counter atomic.Int64
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
n := counter.Add(1)
|
||||
event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityLow, "jailbreak",
|
||||
fmt.Sprintf("Parallel bench #%d", n))
|
||||
event.ID = fmt.Sprintf("bench-par-%d", n)
|
||||
_, _, err := svc.IngestEvent(event)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
153
internal/application/soc/load_test.go
Normal file
153
internal/application/soc/load_test.go
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
package soc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/audit"
|
||||
"github.com/syntrex/gomcp/internal/infrastructure/sqlite"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestLoadTest_SustainedThroughput measures SOC pipeline throughput and latency
|
||||
// under sustained concurrent load. Reports p50/p95/p99 latencies and events/sec.
|
||||
func TestLoadTest_SustainedThroughput(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping load test in short mode")
|
||||
}
|
||||
|
||||
// Setup service with file-based SQLite for concurrency safety.
|
||||
tmpDir := t.TempDir()
|
||||
db, err := sqlite.Open(tmpDir + "/loadtest.db")
|
||||
require.NoError(t, err)
|
||||
|
||||
repo, err := sqlite.NewSOCRepo(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger, err := audit.NewDecisionLogger(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
logger.Close()
|
||||
db.Close()
|
||||
})
|
||||
|
||||
svc := NewService(repo, logger)
|
||||
svc.DisableRateLimit() // bypass rate limiter for raw throughput
|
||||
|
||||
// Load test parameters.
|
||||
const (
|
||||
numWorkers = 16
|
||||
eventsPerWkr = 200
|
||||
totalEvents = numWorkers * eventsPerWkr
|
||||
)
|
||||
|
||||
categories := []string{"jailbreak", "injection", "exfiltration", "auth_bypass", "tool_abuse"}
|
||||
sources := []domsoc.EventSource{domsoc.SourceSentinelCore, domsoc.SourceShield, domsoc.SourceGoMCP}
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
latencies = make([]time.Duration, totalEvents)
|
||||
errors int64
|
||||
incidents int64
|
||||
)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
for w := 0; w < numWorkers; w++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < eventsPerWkr; i++ {
|
||||
idx := workerID*eventsPerWkr + i
|
||||
evt := domsoc.NewSOCEvent(
|
||||
sources[idx%len(sources)],
|
||||
domsoc.SeverityHigh,
|
||||
categories[idx%len(categories)],
|
||||
fmt.Sprintf("load-test w%d-e%d", workerID, i),
|
||||
)
|
||||
evt.SensorID = fmt.Sprintf("load-sensor-%d", workerID)
|
||||
|
||||
t0 := time.Now()
|
||||
_, inc, err := svc.IngestEvent(evt)
|
||||
latencies[idx] = time.Since(t0)
|
||||
|
||||
if err != nil {
|
||||
atomic.AddInt64(&errors, 1)
|
||||
}
|
||||
if inc != nil {
|
||||
atomic.AddInt64(&incidents, 1)
|
||||
}
|
||||
}
|
||||
}(w)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
totalDuration := time.Since(start)
|
||||
|
||||
// Compute latency percentiles.
|
||||
sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] })
|
||||
|
||||
p50 := percentile(latencies, 50)
|
||||
p95 := percentile(latencies, 95)
|
||||
p99 := percentile(latencies, 99)
|
||||
mean := meanDuration(latencies)
|
||||
eventsPerSec := float64(totalEvents) / totalDuration.Seconds()
|
||||
|
||||
// Report results.
|
||||
t.Logf("═══════════════════════════════════════════════")
|
||||
t.Logf(" SENTINEL SOC Load Test Results")
|
||||
t.Logf("═══════════════════════════════════════════════")
|
||||
t.Logf(" Workers: %d", numWorkers)
|
||||
t.Logf(" Events/worker: %d", eventsPerWkr)
|
||||
t.Logf(" Total events: %d", totalEvents)
|
||||
t.Logf(" Duration: %s", totalDuration.Round(time.Millisecond))
|
||||
t.Logf(" Throughput: %.0f events/sec", eventsPerSec)
|
||||
t.Logf("───────────────────────────────────────────────")
|
||||
t.Logf(" Mean: %s", mean.Round(time.Microsecond))
|
||||
t.Logf(" P50: %s", p50.Round(time.Microsecond))
|
||||
t.Logf(" P95: %s", p95.Round(time.Microsecond))
|
||||
t.Logf(" P99: %s", p99.Round(time.Microsecond))
|
||||
t.Logf(" Min: %s", latencies[0].Round(time.Microsecond))
|
||||
t.Logf(" Max: %s", latencies[len(latencies)-1].Round(time.Microsecond))
|
||||
t.Logf("───────────────────────────────────────────────")
|
||||
t.Logf(" Errors: %d (%.1f%%)", errors, float64(errors)/float64(totalEvents)*100)
|
||||
t.Logf(" Incidents: %d", incidents)
|
||||
t.Logf("═══════════════════════════════════════════════")
|
||||
|
||||
// Assertions: basic sanity checks.
|
||||
require.Less(t, float64(errors)/float64(totalEvents), 0.05, "error rate should be < 5%%")
|
||||
require.Greater(t, eventsPerSec, float64(100), "should sustain > 100 events/sec")
|
||||
}
|
||||
|
||||
func percentile(sorted []time.Duration, p int) time.Duration {
|
||||
if len(sorted) == 0 {
|
||||
return 0
|
||||
}
|
||||
idx := int(math.Ceil(float64(p)/100.0*float64(len(sorted)))) - 1
|
||||
if idx < 0 {
|
||||
idx = 0
|
||||
}
|
||||
if idx >= len(sorted) {
|
||||
idx = len(sorted) - 1
|
||||
}
|
||||
return sorted[idx]
|
||||
}
|
||||
|
||||
func meanDuration(ds []time.Duration) time.Duration {
|
||||
if len(ds) == 0 {
|
||||
return 0
|
||||
}
|
||||
var total time.Duration
|
||||
for _, d := range ds {
|
||||
total += d
|
||||
}
|
||||
return total / time.Duration(len(ds))
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -143,7 +143,7 @@ func TestRunPlaybook_IncidentNotFound(t *testing.T) {
|
|||
svc := newTestService(t)
|
||||
|
||||
// Use a valid playbook ID from defaults.
|
||||
_, err := svc.RunPlaybook("pb-auto-block-jailbreak", "nonexistent-inc")
|
||||
_, err := svc.RunPlaybook("pb-block-jailbreak", "nonexistent-inc")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "incident not found")
|
||||
}
|
||||
|
|
|
|||
255
internal/application/soc/stix_feed.go
Normal file
255
internal/application/soc/stix_feed.go
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
package soc
|
||||
|
||||
import (
	"encoding/json"
	"io"
	"log/slog"
	"net/http"
	"strings"
	"time"
)
|
||||
|
||||
// STIXBundle represents a STIX 2.1 bundle (simplified).
// Only the fields needed for IOC extraction are modeled; any other bundle
// properties are ignored by encoding/json during decode.
type STIXBundle struct {
	Type    string       `json:"type"` // always "bundle" for STIX 2.1 bundles
	ID      string       `json:"id"`
	Objects []STIXObject `json:"objects"`
}
|
||||
|
||||
// STIXObject represents a generic STIX 2.1 object.
// It models the subset of properties this package reads (indicator patterns,
// labels, timestamps); unknown JSON fields are silently dropped on decode.
type STIXObject struct {
	Type        string    `json:"type"` // indicator, malware, attack-pattern, etc.
	ID          string    `json:"id"`
	Created     time.Time `json:"created"`
	Modified    time.Time `json:"modified"`
	Name        string    `json:"name,omitempty"`
	Description string    `json:"description,omitempty"`
	Pattern     string    `json:"pattern,omitempty"`      // STIX pattern (indicators)
	PatternType string    `json:"pattern_type,omitempty"` // stix, pcre, sigma
	ValidFrom   time.Time `json:"valid_from,omitempty"`
	Labels      []string  `json:"labels,omitempty"`
	// Kill chain phases for attack-pattern objects.
	KillChainPhases []struct {
		KillChainName string `json:"kill_chain_name"`
		PhaseName     string `json:"phase_name"`
	} `json:"kill_chain_phases,omitempty"`
	// External references (CVE, etc.)
	ExternalReferences []struct {
		SourceName  string `json:"source_name"`
		ExternalID  string `json:"external_id,omitempty"`
		URL         string `json:"url,omitempty"`
		Description string `json:"description,omitempty"`
	} `json:"external_references,omitempty"`
}
|
||||
|
||||
// STIXFeedConfig configures automatic STIX feed polling.
type STIXFeedConfig struct {
	Name    string            `json:"name"`    // Feed name (e.g., "OTX", "MISP")
	URL     string            `json:"url"`     // TAXII or HTTP feed URL
	APIKey  string            `json:"api_key"` // Authentication key
	Headers map[string]string `json:"headers"` // Additional headers, set verbatim on each request
	// Poll interval; a zero value falls back to one hour (see pollFeed).
	Interval time.Duration `json:"interval"`
	// Disabled feeds are skipped entirely by FeedSync.Start.
	Enabled bool `json:"enabled"`
}
|
||||
|
||||
// FeedSync syncs IOCs from STIX/TAXII feeds into the ThreatIntelStore.
// Construct with NewFeedSync; the zero value has no HTTP client and is
// not usable.
type FeedSync struct {
	feeds  []STIXFeedConfig  // feeds to poll; only Enabled ones are started
	store  *ThreatIntelStore // destination for imported IOCs
	client *http.Client      // shared client with a fetch timeout
}
|
||||
|
||||
// NewFeedSync creates a feed synchronizer.
|
||||
func NewFeedSync(store *ThreatIntelStore, feeds []STIXFeedConfig) *FeedSync {
|
||||
return &FeedSync{
|
||||
feeds: feeds,
|
||||
store: store,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins polling all enabled feeds in the background.
|
||||
func (f *FeedSync) Start(done <-chan struct{}) {
|
||||
for _, feed := range f.feeds {
|
||||
if !feed.Enabled {
|
||||
continue
|
||||
}
|
||||
go f.pollFeed(feed, done)
|
||||
}
|
||||
}
|
||||
|
||||
// pollFeed periodically fetches and processes a single STIX feed.
|
||||
func (f *FeedSync) pollFeed(feed STIXFeedConfig, done <-chan struct{}) {
|
||||
interval := feed.Interval
|
||||
if interval == 0 {
|
||||
interval = time.Hour
|
||||
}
|
||||
|
||||
slog.Info("stix feed started", "feed", feed.Name, "url", feed.URL, "interval", interval)
|
||||
|
||||
// Initial fetch.
|
||||
f.fetchFeed(feed)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
slog.Info("stix feed stopped", "feed", feed.Name)
|
||||
return
|
||||
case <-ticker.C:
|
||||
f.fetchFeed(feed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFeed performs a single HTTP GET and processes the STIX bundle.
|
||||
func (f *FeedSync) fetchFeed(feed STIXFeedConfig) {
|
||||
req, err := http.NewRequest(http.MethodGet, feed.URL, nil)
|
||||
if err != nil {
|
||||
slog.Error("stix feed: request error", "feed", feed.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/stix+json;version=2.1")
|
||||
if feed.APIKey != "" {
|
||||
req.Header.Set("X-OTX-API-KEY", feed.APIKey)
|
||||
}
|
||||
for k, v := range feed.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := f.client.Do(req)
|
||||
if err != nil {
|
||||
slog.Error("stix feed: fetch error", "feed", feed.Name, "error", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
slog.Warn("stix feed: non-200 response", "feed", feed.Name, "status", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
var bundle STIXBundle
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bundle); err != nil {
|
||||
slog.Error("stix feed: decode error", "feed", feed.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
imported := f.processBundle(feed.Name, bundle)
|
||||
slog.Info("stix feed synced",
|
||||
"feed", feed.Name,
|
||||
"objects", len(bundle.Objects),
|
||||
"iocs_imported", imported,
|
||||
)
|
||||
}
|
||||
|
||||
// processBundle extracts IOCs from STIX indicators and adds to the store.
|
||||
func (f *FeedSync) processBundle(feedName string, bundle STIXBundle) int {
|
||||
imported := 0
|
||||
for _, obj := range bundle.Objects {
|
||||
if obj.Type != "indicator" || obj.Pattern == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ioc := stixPatternToIOC(obj)
|
||||
if ioc == nil {
|
||||
continue
|
||||
}
|
||||
ioc.Source = feedName
|
||||
ioc.Tags = obj.Labels
|
||||
|
||||
f.store.AddIOC(*ioc)
|
||||
imported++
|
||||
}
|
||||
return imported
|
||||
}
|
||||
|
||||
// stixPatternToIOC converts a STIX indicator pattern to our IOC format.
|
||||
// Supports: [file:hashes.'SHA-256' = '...'], [ipv4-addr:value = '...'],
|
||||
// [domain-name:value = '...'], [url:value = '...']
|
||||
func stixPatternToIOC(obj STIXObject) *IOC {
|
||||
pattern := obj.Pattern
|
||||
now := obj.Modified
|
||||
if now.IsZero() {
|
||||
now = obj.Created
|
||||
}
|
||||
ioc := &IOC{
|
||||
Value: "",
|
||||
Severity: "medium",
|
||||
FirstSeen: now,
|
||||
LastSeen: now,
|
||||
Confidence: 0.7,
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(pattern, "file:hashes"):
|
||||
ioc.Type = IOCTypeHash
|
||||
ioc.Value = extractSTIXValue(pattern)
|
||||
case strings.Contains(pattern, "ipv4-addr:value"):
|
||||
ioc.Type = IOCTypeIP
|
||||
ioc.Value = extractSTIXValue(pattern)
|
||||
case strings.Contains(pattern, "domain-name:value"):
|
||||
ioc.Type = IOCTypeDomain
|
||||
ioc.Value = extractSTIXValue(pattern)
|
||||
case strings.Contains(pattern, "url:value"):
|
||||
ioc.Type = IOCTypeURL
|
||||
ioc.Value = extractSTIXValue(pattern)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if ioc.Value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Derive severity from STIX labels.
|
||||
for _, label := range obj.Labels {
|
||||
switch {
|
||||
case strings.Contains(label, "anomalous-activity"):
|
||||
ioc.Severity = "low"
|
||||
case strings.Contains(label, "malicious-activity"):
|
||||
ioc.Severity = "critical"
|
||||
case strings.Contains(label, "attribution"):
|
||||
ioc.Severity = "high"
|
||||
}
|
||||
}
|
||||
|
||||
return ioc
|
||||
}
|
||||
|
||||
// extractSTIXValue pulls the single-quoted value from a STIX comparison
// pattern like:
//
//	[ipv4-addr:value = '192.168.1.1']
//	[file:hashes.'SHA-256' = 'e3b0c44...']
//
// It anchors on the '=' operator and takes the first quoted string after
// it, which skips quotes that are part of the object path (SHA-256 above).
// Unlike the previous version, which required the exact sequence "= '",
// this also accepts patterns without a space after '=' (e.g. "='x'").
// Returns "" when no value can be extracted.
func extractSTIXValue(pattern string) string {
	eq := strings.Index(pattern, "=")
	if eq < 0 {
		return ""
	}
	rest := pattern[eq+1:]

	open := strings.Index(rest, "'")
	if open < 0 {
		return ""
	}
	rest = rest[open+1:]

	end := strings.Index(rest, "'")
	if end < 0 {
		return ""
	}
	return rest[:end]
}
|
||||
|
||||
// DefaultOTXFeed returns a pre-configured AlienVault OTX feed config.
|
||||
func DefaultOTXFeed(apiKey string) STIXFeedConfig {
|
||||
return STIXFeedConfig{
|
||||
Name: "AlienVault OTX",
|
||||
URL: "https://otx.alienvault.com/api/v1/pulses/subscribed",
|
||||
APIKey: apiKey,
|
||||
Interval: time.Hour,
|
||||
Enabled: apiKey != "",
|
||||
Headers: map[string]string{
|
||||
"X-OTX-API-KEY": apiKey,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// IOC type is defined in threat_intel.go — this file uses it directly.
|
||||
137
internal/application/soc/stix_feed_test.go
Normal file
137
internal/application/soc/stix_feed_test.go
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
package soc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- stixPatternToIOC ---
|
||||
|
||||
func TestSTIXPatternToIOC_IPv4(t *testing.T) {
|
||||
obj := STIXObject{
|
||||
Type: "indicator",
|
||||
Pattern: "[ipv4-addr:value = '192.168.1.1']",
|
||||
Modified: time.Now(),
|
||||
}
|
||||
ioc := stixPatternToIOC(obj)
|
||||
require.NotNil(t, ioc, "should parse IPv4 pattern")
|
||||
assert.Equal(t, IOCTypeIP, ioc.Type)
|
||||
assert.Equal(t, "192.168.1.1", ioc.Value)
|
||||
assert.Equal(t, "medium", ioc.Severity)
|
||||
assert.False(t, ioc.FirstSeen.IsZero(), "FirstSeen must be set")
|
||||
}
|
||||
|
||||
func TestSTIXPatternToIOC_Hash(t *testing.T) {
|
||||
hash := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
obj := STIXObject{
|
||||
Type: "indicator",
|
||||
Pattern: "[file:hashes.'SHA-256' = '" + hash + "']",
|
||||
Modified: time.Now(),
|
||||
Labels: []string{"malicious-activity"},
|
||||
}
|
||||
ioc := stixPatternToIOC(obj)
|
||||
require.NotNil(t, ioc, "should parse hash pattern")
|
||||
assert.Equal(t, IOCTypeHash, ioc.Type)
|
||||
assert.Equal(t, hash, ioc.Value)
|
||||
assert.Equal(t, "critical", ioc.Severity, "malicious-activity label → critical")
|
||||
}
|
||||
|
||||
func TestSTIXPatternToIOC_Domain(t *testing.T) {
|
||||
obj := STIXObject{
|
||||
Type: "indicator",
|
||||
Pattern: "[domain-name:value = 'evil.example.com']",
|
||||
Modified: time.Now(),
|
||||
Labels: []string{"attribution"},
|
||||
}
|
||||
ioc := stixPatternToIOC(obj)
|
||||
require.NotNil(t, ioc)
|
||||
assert.Equal(t, IOCTypeDomain, ioc.Type)
|
||||
assert.Equal(t, "evil.example.com", ioc.Value)
|
||||
assert.Equal(t, "high", ioc.Severity, "attribution label → high")
|
||||
}
|
||||
|
||||
func TestSTIXPatternToIOC_Unsupported(t *testing.T) {
|
||||
obj := STIXObject{
|
||||
Type: "indicator",
|
||||
Pattern: "[email-addr:value = 'attacker@evil.com']",
|
||||
Modified: time.Now(),
|
||||
}
|
||||
ioc := stixPatternToIOC(obj)
|
||||
assert.Nil(t, ioc, "unsupported pattern type should return nil")
|
||||
}
|
||||
|
||||
func TestSTIXPatternToIOC_FallbackToCreated(t *testing.T) {
|
||||
created := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
obj := STIXObject{
|
||||
Type: "indicator",
|
||||
Pattern: "[ipv4-addr:value = '10.0.0.1']",
|
||||
Created: created,
|
||||
// Modified is zero → should fall back to Created
|
||||
}
|
||||
ioc := stixPatternToIOC(obj)
|
||||
require.NotNil(t, ioc)
|
||||
assert.Equal(t, created, ioc.FirstSeen, "should fall back to Created when Modified is zero")
|
||||
}
|
||||
|
||||
// --- extractSTIXValue ---
|
||||
|
||||
func TestExtractSTIXValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
want string
|
||||
}{
|
||||
{"ipv4", "[ipv4-addr:value = '1.2.3.4']", "1.2.3.4"},
|
||||
{"domain", "[domain-name:value = 'evil.com']", "evil.com"},
|
||||
{"hash", "[file:hashes.'SHA-256' = 'abc123']", "abc123"},
|
||||
{"empty_no_quotes", "[ipv4-addr:value = ]", ""},
|
||||
{"single_quote_only", "'", ""},
|
||||
{"empty_string", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractSTIXValue(tt.pattern)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- processBundle ---
|
||||
|
||||
func TestProcessBundle_FiltersNonIndicators(t *testing.T) {
|
||||
store := NewThreatIntelStore()
|
||||
fs := NewFeedSync(store, nil)
|
||||
|
||||
bundle := STIXBundle{
|
||||
Type: "bundle",
|
||||
ID: "bundle--test",
|
||||
Objects: []STIXObject{
|
||||
{Type: "indicator", Pattern: "[ipv4-addr:value = '10.0.0.1']", Modified: time.Now()},
|
||||
{Type: "malware", Name: "BadMalware"}, // should be skipped
|
||||
{Type: "indicator", Pattern: ""}, // empty pattern → skipped
|
||||
{Type: "attack-pattern", Name: "Phish"}, // should be skipped
|
||||
{Type: "indicator", Pattern: "[domain-name:value = 'bad.com']", Modified: time.Now()},
|
||||
},
|
||||
}
|
||||
|
||||
imported := fs.processBundle("test-feed", bundle)
|
||||
assert.Equal(t, 2, imported, "should import only 2 valid indicators")
|
||||
assert.Equal(t, 2, store.TotalIOCs, "store should have 2 IOCs")
|
||||
}
|
||||
|
||||
// --- DefaultOTXFeed ---
|
||||
|
||||
func TestDefaultOTXFeed(t *testing.T) {
|
||||
feed := DefaultOTXFeed("test-key-123")
|
||||
assert.Equal(t, "AlienVault OTX", feed.Name)
|
||||
assert.True(t, feed.Enabled, "should be enabled when key provided")
|
||||
assert.Contains(t, feed.URL, "otx.alienvault.com")
|
||||
assert.Equal(t, time.Hour, feed.Interval)
|
||||
assert.Equal(t, "test-key-123", feed.Headers["X-OTX-API-KEY"])
|
||||
|
||||
disabled := DefaultOTXFeed("")
|
||||
assert.False(t, disabled.Enabled, "should be disabled when key is empty")
|
||||
}
|
||||
|
|
@ -6,8 +6,8 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -58,9 +58,9 @@ type WebhookNotifier struct {
|
|||
client *http.Client
|
||||
enabled bool
|
||||
|
||||
// Stats
|
||||
Sent int64 `json:"sent"`
|
||||
Failed int64 `json:"failed"`
|
||||
// Stats (unexported — access via Stats() method)
|
||||
sent int64
|
||||
failed int64
|
||||
}
|
||||
|
||||
// NewWebhookNotifier creates a notifier with the given config.
|
||||
|
|
@ -80,23 +80,7 @@ func NewWebhookNotifier(config WebhookConfig) *WebhookNotifier {
|
|||
}
|
||||
}
|
||||
|
||||
// severityRank returns numeric rank for severity comparison.
|
||||
func severityRank(s domsoc.EventSeverity) int {
|
||||
switch s {
|
||||
case domsoc.SeverityCritical:
|
||||
return 5
|
||||
case domsoc.SeverityHigh:
|
||||
return 4
|
||||
case domsoc.SeverityMedium:
|
||||
return 3
|
||||
case domsoc.SeverityLow:
|
||||
return 2
|
||||
case domsoc.SeverityInfo:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// NotifyIncident sends an incident webhook to all configured endpoints.
|
||||
// Non-blocking: fires goroutines for each endpoint.
|
||||
|
|
@ -105,9 +89,9 @@ func (w *WebhookNotifier) NotifyIncident(eventType string, incident *domsoc.Inci
|
|||
return nil
|
||||
}
|
||||
|
||||
// Severity filter
|
||||
// Severity filter — use domain Rank() method (Q-1 FIX: removed duplicate severityRank).
|
||||
if w.config.MinSeverity != "" {
|
||||
if severityRank(incident.Severity) < severityRank(w.config.MinSeverity) {
|
||||
if incident.Severity.Rank() < w.config.MinSeverity.Rank() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -146,9 +130,9 @@ func (w *WebhookNotifier) NotifyIncident(eventType string, incident *domsoc.Inci
|
|||
w.mu.Lock()
|
||||
for _, r := range results {
|
||||
if r.Success {
|
||||
w.Sent++
|
||||
w.sent++
|
||||
} else {
|
||||
w.Failed++
|
||||
w.failed++
|
||||
}
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
|
@ -213,7 +197,7 @@ func (w *WebhookNotifier) sendWithRetry(url string, body []byte) WebhookResult {
|
|||
result.Error = err.Error()
|
||||
if attempt < w.config.MaxRetries {
|
||||
backoff := time.Duration(1<<uint(attempt)) * 500 * time.Millisecond
|
||||
jitter := time.Duration(rand.Intn(500)) * time.Millisecond
|
||||
jitter := time.Duration(rand.IntN(500)) * time.Millisecond
|
||||
time.Sleep(backoff + jitter)
|
||||
continue
|
||||
}
|
||||
|
|
@ -230,12 +214,12 @@ func (w *WebhookNotifier) sendWithRetry(url string, body []byte) WebhookResult {
|
|||
result.Error = fmt.Sprintf("HTTP %d", resp.StatusCode)
|
||||
if attempt < w.config.MaxRetries {
|
||||
backoff := time.Duration(1<<uint(attempt)) * 500 * time.Millisecond
|
||||
jitter := time.Duration(rand.Intn(500)) * time.Millisecond
|
||||
jitter := time.Duration(rand.IntN(500)) * time.Millisecond
|
||||
time.Sleep(backoff + jitter)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[SOC] webhook failed after %d retries: %s → %s", w.config.MaxRetries, url, result.Error)
|
||||
slog.Error("webhook failed", "retries", w.config.MaxRetries, "url", url, "error", result.Error)
|
||||
return result
|
||||
}
|
||||
|
||||
|
|
@ -243,5 +227,5 @@ func (w *WebhookNotifier) sendWithRetry(url string, body []byte) WebhookResult {
|
|||
func (w *WebhookNotifier) Stats() (sent, failed int64) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
return w.Sent, w.Failed
|
||||
return w.sent, w.failed
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue