Release prep: 54 engines, self-hosted signatures, i18n, dashboard updates

This commit is contained in:
DmitrL-dev 2026-03-23 16:45:40 +10:00
parent 694e32be26
commit 41cbfd6e0a
178 changed files with 36008 additions and 399 deletions

View file

@ -0,0 +1,165 @@
package resilience
import (
"context"
"log/slog"
"runtime"
"sync"
"time"
)
// BehaviorProfile captures the runtime behavior of a component.
// collectProfile populates the Go runtime fields; FileDescriptors and
// CustomMetrics are not set by collectProfile and are reserved for callers.
type BehaviorProfile struct {
	Goroutines      int                `json:"goroutines"`                 // current goroutine count
	HeapAllocMB     float64            `json:"heap_alloc_mb"`              // live heap bytes, converted to MiB
	HeapObjectsK    float64            `json:"heap_objects_k"`             // live heap objects, in thousands
	GCPauseMs       float64            `json:"gc_pause_ms"`                // most recent GC pause, milliseconds
	NumGC           uint32             `json:"num_gc"`                     // completed GC cycles since process start
	FileDescriptors int                `json:"file_descriptors,omitempty"` // optional; not filled by collectProfile
	CustomMetrics   map[string]float64 `json:"custom_metrics,omitempty"`   // optional free-form metrics
}
// BehavioralAlert is emitted when a behavioral anomaly is detected.
// It records the offending metric, its current and baseline values, and the
// Z-score that tripped the threshold.
type BehavioralAlert struct {
	Component   string    `json:"component"`    // component being profiled
	AnomalyType string    `json:"anomaly_type"` // goroutine_leak, memory_leak, gc_pressure, etc.
	Metric      string    `json:"metric"`       // metric name that breached the baseline
	Current     float64   `json:"current"`      // observed value
	Baseline    float64   `json:"baseline"`     // baseline mean the value was compared against
	ZScore      float64   `json:"z_score"`      // standard deviations from the baseline mean
	Severity    string    `json:"severity"`     // "WARNING" or "CRITICAL"
	Timestamp   time.Time `json:"timestamp"`    // when the anomaly was detected
}
// BehavioralAnalyzer provides Go-side runtime behavioral analysis.
// It profiles the current process and compares against learned baselines.
// On Linux, eBPF hooks (immune/resilience_hooks.c) extend this to kernel level.
type BehavioralAnalyzer struct {
	mu        sync.RWMutex // NOTE(review): not acquired by any method in this file — confirm it is needed
	metricsDB *MetricsDB   // time-series store for per-metric baselines
	alertBus  chan BehavioralAlert
	interval  time.Duration // sampling period for Start's ticker
	component string        // self component name
	logger    *slog.Logger
}
// NewBehavioralAnalyzer creates a new behavioral analyzer for the named
// component. alertBufSize sizes the alert channel; non-positive values fall
// back to a default of 50. Sampling runs once per minute.
func NewBehavioralAnalyzer(component string, alertBufSize int) *BehavioralAnalyzer {
	bufSize := alertBufSize
	if bufSize <= 0 {
		bufSize = 50
	}
	ba := &BehavioralAnalyzer{
		metricsDB: NewMetricsDB(DefaultMetricsWindow, DefaultMetricsMaxSize),
		alertBus:  make(chan BehavioralAlert, bufSize),
		interval:  time.Minute,
		component: component,
		logger:    slog.Default().With("component", "sarl-behavioral"),
	}
	return ba
}
// AlertBus exposes the receive side of the alert channel so consumers can
// drain behavioral alerts as they are emitted.
func (ba *BehavioralAnalyzer) AlertBus() <-chan BehavioralAlert {
	return ba.alertBus
}
// Start runs the continuous monitoring loop: on every tick it collects a
// runtime profile, records it, and scans for anomalies. Blocks until ctx is
// cancelled.
func (ba *BehavioralAnalyzer) Start(ctx context.Context) {
	ba.logger.Info("behavioral analyzer started", "interval", ba.interval)
	ticker := time.NewTicker(ba.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			ba.collectAndAnalyze()
		case <-ctx.Done():
			ba.logger.Info("behavioral analyzer stopped")
			return
		}
	}
}
// collectAndAnalyze takes one runtime snapshot, stores it in the metrics DB,
// and checks it against the learned baselines.
func (ba *BehavioralAnalyzer) collectAndAnalyze() {
	p := ba.collectProfile()
	ba.storeMetrics(p)
	ba.detectAnomalies(p)
}
// collectProfile gathers a snapshot of current Go runtime statistics.
func (ba *BehavioralAnalyzer) collectProfile() BehaviorProfile {
	var ms runtime.MemStats
	runtime.ReadMemStats(&ms)
	// PauseNs is a circular buffer of the last 256 GC pauses; the entry at
	// (NumGC+255)%256 is the most recent one.
	lastPauseNs := ms.PauseNs[(ms.NumGC+255)%256]
	return BehaviorProfile{
		Goroutines:   runtime.NumGoroutine(),
		HeapAllocMB:  float64(ms.HeapAlloc) / (1 << 20),
		HeapObjectsK: float64(ms.HeapObjects) / 1000,
		GCPauseMs:    float64(lastPauseNs) / 1e6,
		NumGC:        ms.NumGC,
	}
}
// storeMetrics records the profile's four runtime metrics in the time-series
// DB under this analyzer's component name.
func (ba *BehavioralAnalyzer) storeMetrics(p BehaviorProfile) {
	points := map[string]float64{
		"goroutines":     float64(p.Goroutines),
		"heap_alloc_mb":  p.HeapAllocMB,
		"heap_objects_k": p.HeapObjectsK,
		"gc_pause_ms":    p.GCPauseMs,
	}
	for metric, value := range points {
		ba.metricsDB.AddDataPoint(ba.component, metric, value)
	}
}
// detectAnomalies compares each metric in the profile against its learned
// baseline using a Z-score test and emits an alert for every breach. The
// channel send is non-blocking: if the bus is full the alert is dropped and
// an error is logged instead.
func (ba *BehavioralAnalyzer) detectAnomalies(p BehaviorProfile) {
	type check struct {
		metric      string
		value       float64
		anomalyType string
		severity    string
	}
	checks := []check{
		{"goroutines", float64(p.Goroutines), "goroutine_leak", "WARNING"},
		{"heap_alloc_mb", p.HeapAllocMB, "memory_leak", "CRITICAL"},
		{"heap_objects_k", p.HeapObjectsK, "object_leak", "WARNING"},
		{"gc_pause_ms", p.GCPauseMs, "gc_pressure", "WARNING"},
	}
	for _, chk := range checks {
		baseline := ba.metricsDB.GetBaseline(ba.component, chk.metric, DefaultMetricsWindow)
		if !IsAnomaly(chk.value, baseline, AnomalyZScoreThreshold) {
			continue
		}
		z := CalculateZScore(chk.value, baseline)
		alert := BehavioralAlert{
			Component:   ba.component,
			AnomalyType: chk.anomalyType,
			Metric:      chk.metric,
			Current:     chk.value,
			Baseline:    baseline.Mean,
			ZScore:      z,
			Severity:    chk.severity,
			Timestamp:   time.Now(),
		}
		select {
		case ba.alertBus <- alert:
			ba.logger.Warn("behavioral anomaly detected",
				"type", chk.anomalyType,
				"metric", chk.metric,
				"z_score", z,
			)
		default:
			ba.logger.Error("behavioral alert bus full")
		}
	}
}
// InjectMetric records an arbitrary data point for this component directly in
// the metrics DB. Intended for building baselines in tests.
func (ba *BehavioralAnalyzer) InjectMetric(metric string, value float64) {
	ba.metricsDB.AddDataPoint(ba.component, metric, value)
}
// CurrentProfile takes and returns an immediate snapshot of the runtime
// profile without storing it or running anomaly detection.
func (ba *BehavioralAnalyzer) CurrentProfile() BehaviorProfile {
	return ba.collectProfile()
}

View file

@ -0,0 +1,206 @@
package resilience
import (
"context"
"testing"
"time"
)
// IM-01: Goroutine leak detection.
// A stable baseline of 10 goroutines followed by a spike to 1000 must
// produce a goroutine_leak alert with a Z-score above 3.
func TestBehavioral_IM01_GoroutineLeak(t *testing.T) {
	ba := NewBehavioralAnalyzer("soc-ingest", 10)
	for n := 0; n < 50; n++ {
		ba.InjectMetric("goroutines", 10)
	}
	ba.metricsDB.AddDataPoint("soc-ingest", "goroutines", 1000)
	ba.detectAnomalies(BehaviorProfile{Goroutines: 1000})
	select {
	case got := <-ba.alertBus:
		if got.AnomalyType != "goroutine_leak" {
			t.Errorf("expected goroutine_leak, got %s", got.AnomalyType)
		}
		if got.ZScore <= 3 {
			t.Errorf("expected Z > 3, got %f", got.ZScore)
		}
	default:
		t.Error("expected goroutine leak alert")
	}
}
// IM-02: Memory leak detection.
// Heap growth from a 50 MB baseline to 500 MB must produce a CRITICAL
// memory_leak alert.
func TestBehavioral_IM02_MemoryLeak(t *testing.T) {
	ba := NewBehavioralAnalyzer("soc-correlate", 10)
	for n := 0; n < 50; n++ {
		ba.InjectMetric("heap_alloc_mb", 50)
	}
	ba.metricsDB.AddDataPoint("soc-correlate", "heap_alloc_mb", 500)
	ba.detectAnomalies(BehaviorProfile{HeapAllocMB: 500})
	select {
	case got := <-ba.alertBus:
		if got.AnomalyType != "memory_leak" {
			t.Errorf("expected memory_leak, got %s", got.AnomalyType)
		}
		if got.Severity != "CRITICAL" {
			t.Errorf("expected CRITICAL severity, got %s", got.Severity)
		}
	default:
		t.Error("expected memory leak alert")
	}
}
// IM-03: GC pressure detection.
// A pause-time jump from 1ms baseline to 100ms must raise gc_pressure.
func TestBehavioral_IM03_GCPressure(t *testing.T) {
	ba := NewBehavioralAnalyzer("soc-respond", 10)
	for n := 0; n < 50; n++ {
		ba.InjectMetric("gc_pause_ms", 1)
	}
	ba.metricsDB.AddDataPoint("soc-respond", "gc_pause_ms", 100)
	ba.detectAnomalies(BehaviorProfile{GCPauseMs: 100})
	select {
	case got := <-ba.alertBus:
		if got.AnomalyType != "gc_pressure" {
			t.Errorf("expected gc_pressure, got %s", got.AnomalyType)
		}
	default:
		t.Error("expected gc_pressure alert")
	}
}
// IM-04: Object leak detection.
// Heap-object count spiking from 100k baseline to 5000k must raise
// object_leak.
func TestBehavioral_IM04_ObjectLeak(t *testing.T) {
	ba := NewBehavioralAnalyzer("shield", 10)
	for n := 0; n < 50; n++ {
		ba.InjectMetric("heap_objects_k", 100)
	}
	ba.metricsDB.AddDataPoint("shield", "heap_objects_k", 5000)
	ba.detectAnomalies(BehaviorProfile{HeapObjectsK: 5000})
	select {
	case got := <-ba.alertBus:
		if got.AnomalyType != "object_leak" {
			t.Errorf("expected object_leak, got %s", got.AnomalyType)
		}
	default:
		t.Error("expected object leak alert")
	}
}
// IM-05: Normal behavior — no alerts.
// When the current profile matches its own baseline exactly, no alert may
// appear on the bus.
func TestBehavioral_IM05_NormalBehavior(t *testing.T) {
	ba := NewBehavioralAnalyzer("sidecar", 10)
	steady := map[string]float64{
		"goroutines":     10,
		"heap_alloc_mb":  50,
		"heap_objects_k": 100,
		"gc_pause_ms":    1,
	}
	for n := 0; n < 50; n++ {
		for metric, value := range steady {
			ba.InjectMetric(metric, value)
		}
	}
	ba.detectAnomalies(BehaviorProfile{
		Goroutines:   10,
		HeapAllocMB:  50,
		HeapObjectsK: 100,
		GCPauseMs:    1,
	})
	select {
	case got := <-ba.alertBus:
		t.Errorf("expected no alerts for normal behavior, got %+v", got)
	default:
		// Good — no alerts.
	}
}
// IM-06: Start/Stop lifecycle.
// Start must return promptly once its context is cancelled.
func TestBehavioral_IM06_StartStop(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 10)
	ba.interval = 50 * time.Millisecond
	ctx, cancel := context.WithCancel(context.Background())
	stopped := make(chan struct{})
	go func() {
		ba.Start(ctx)
		close(stopped)
	}()
	time.Sleep(100 * time.Millisecond)
	cancel()
	select {
	case <-stopped:
	case <-time.After(time.Second):
		t.Fatal("Start() did not return after context cancellation")
	}
}
// IM-07: CurrentProfile returns valid data.
// A live process always has at least one goroutine and a non-empty heap.
func TestBehavioral_IM07_CurrentProfile(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 10)
	p := ba.CurrentProfile()
	if p.Goroutines <= 0 {
		t.Error("expected positive goroutine count")
	}
	if p.HeapAllocMB <= 0 {
		t.Error("expected positive heap alloc")
	}
}
// IM-08: Alert bus overflow (non-blocking).
// With the bus pre-filled to capacity, an anomaly must be dropped without
// blocking or panicking.
func TestBehavioral_IM08_AlertBusOverflow(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 2)
	// Pre-fill the bus to capacity.
	for _, kind := range []string{"fill1", "fill2"} {
		ba.alertBus <- BehavioralAlert{AnomalyType: kind}
	}
	// Build a baseline, then trigger a massive spike.
	for n := 0; n < 50; n++ {
		ba.InjectMetric("goroutines", 10)
	}
	ba.metricsDB.AddDataPoint("test", "goroutines", 10000)
	// Must not panic even though the bus cannot accept the alert.
	ba.detectAnomalies(BehaviorProfile{Goroutines: 10000})
}
// Smoke test: a full collect/store/analyze cycle on a fresh analyzer must
// complete without panicking.
func TestBehavioral_CollectAndAnalyze(t *testing.T) {
	NewBehavioralAnalyzer("test", 10).collectAndAnalyze()
}
// InjectMetric must persist the data point so it is visible via GetRecent.
func TestBehavioral_InjectMetric(t *testing.T) {
	ba := NewBehavioralAnalyzer("test", 10)
	ba.InjectMetric("custom", 42.0)
	points := ba.metricsDB.GetRecent("test", "custom", 1)
	if len(points) != 1 || points[0].Value != 42.0 {
		t.Errorf("expected 42.0, got %v", points)
	}
}

View file

@ -0,0 +1,524 @@
package resilience
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// HealingState represents the FSM state of a healing operation.
// Operations move IDLE → DIAGNOSING → HEALING → VERIFYING → COMPLETED/FAILED
// (see initiateHealing).
type HealingState string

const (
	HealingIdle       HealingState = "IDLE"        // created, not yet started
	HealingDiagnosing HealingState = "DIAGNOSING"  // root-cause analysis in progress
	HealingActive     HealingState = "HEALING"     // strategy actions executing
	HealingVerifying  HealingState = "VERIFYING"   // post-heal health check
	HealingCompleted  HealingState = "COMPLETED"   // terminal: healing succeeded
	HealingFailed     HealingState = "FAILED"      // terminal: healing failed
)
// HealingResult summarizes a completed healing operation.
type HealingResult string

const (
	ResultSuccess HealingResult = "SUCCESS" // all actions ran and verification passed
	ResultFailed  HealingResult = "FAILED"  // an action or verification failed
	ResultSkipped HealingResult = "SKIPPED" // operation was not executed
)
// ActionType defines the kinds of healing actions a strategy may run.
// The executor callback (ActionExecutorFunc) interprets each type.
type ActionType string

const (
	// Component lifecycle.
	ActionGracefulStop   ActionType = "graceful_stop"
	ActionClearTempFiles ActionType = "clear_temp_files"
	ActionStartComponent ActionType = "start_component"
	ActionVerifyHealth   ActionType = "verify_health"
	ActionNotifySOC      ActionType = "notify_soc"
	// Configuration recovery.
	ActionFreezeConfig   ActionType = "freeze_config"
	ActionRollbackConfig ActionType = "rollback_config"
	ActionVerifyConfig   ActionType = "verify_config"
	// Database recovery.
	ActionSwitchReadOnly  ActionType = "switch_to_readonly"
	ActionBackupDB        ActionType = "backup_db"
	ActionRestoreSnapshot ActionType = "restore_snapshot"
	ActionVerifyIntegrity ActionType = "verify_integrity"
	ActionResumeWrites    ActionType = "resume_writes"
	// Rule-engine recovery.
	ActionDisableRules ActionType = "disable_rules"
	ActionRevertRules  ActionType = "revert_rules"
	ActionReloadEngine ActionType = "reload_engine"
	// Network recovery.
	ActionIsolateNetwork ActionType = "isolate_network"
	ActionRegenCerts     ActionType = "regenerate_certs"
	ActionRestoreNetwork ActionType = "restore_network"
	// Escalation.
	ActionNotifyArchitect ActionType = "notify_architect"
	ActionEnterSafeMode   ActionType = "enter_safe_mode"
)
// Action is a single step in a healing strategy.
type Action struct {
	Type    ActionType             `json:"type"`             // which operation to run
	Params  map[string]interface{} `json:"params,omitempty"` // executor-specific parameters
	Timeout time.Duration          `json:"timeout"`          // per-action deadline; <= 0 means no timeout
	OnError string                 `json:"on_error"`         // "continue", "abort", "rollback"
}
// TriggerCondition defines when a healing strategy activates.
// matchesTrigger matches an alert against Metrics (by name) or Statuses
// (by alert severity / suggested action).
type TriggerCondition struct {
	Metrics             []string          `json:"metrics,omitempty"`  // metric names that activate the strategy
	Statuses            []ComponentStatus `json:"statuses,omitempty"` // component statuses that activate the strategy
	ConsecutiveFailures int               `json:"consecutive_failures"`
	WithinWindow        time.Duration     `json:"within_window"`
}
// RollbackPlan defines what happens if healing fails.
type RollbackPlan struct {
	OnFailure string   `json:"on_failure"`        // "escalate", "enter_safe_mode", "maintain_isolation"
	Actions   []Action `json:"actions,omitempty"` // compensating actions run best-effort on failure
}
// HealingStrategy is a complete self-healing plan: a trigger, an ordered
// list of actions, and a rollback plan if those actions fail.
type HealingStrategy struct {
	ID          string           `json:"id"`           // unique strategy identifier (also the cooldown key)
	Name        string           `json:"name"`         // human-readable name
	Trigger     TriggerCondition `json:"trigger"`      // when this strategy activates
	Actions     []Action         `json:"actions"`      // steps executed in order
	Rollback    RollbackPlan     `json:"rollback"`     // compensation on failure
	MaxAttempts int              `json:"max_attempts"` // NOTE(review): not enforced in this file — confirm
	Cooldown    time.Duration    `json:"cooldown"`     // minimum gap between runs of this strategy
}
// Diagnosis is the result of root cause analysis (see diagnose).
type Diagnosis struct {
	Component     string        `json:"component"`                // component the alert came from
	Metric        string        `json:"metric"`                   // metric that triggered the alert
	RootCause     string        `json:"root_cause"`               // inferred cause, e.g. "memory_exhaustion"
	Confidence    float64       `json:"confidence"`               // 0..1 heuristic confidence
	SuggestedFix  string        `json:"suggested_fix"`            // human-readable remediation hint
	RelatedAlerts []HealthAlert `json:"related_alerts,omitempty"` // optional supporting alerts
}
// HealingOperation tracks a single healing attempt through the FSM.
type HealingOperation struct {
	ID            string        `json:"id"`          // "heal-N", assigned by createOperation
	StrategyID    string        `json:"strategy_id"` // strategy that produced this operation
	Component     string        `json:"component"`   // component being healed
	State         HealingState  `json:"state"`       // current FSM state
	Diagnosis     *Diagnosis    `json:"diagnosis,omitempty"`
	ActionsRun    []ActionLog   `json:"actions_run"` // audit trail of executed actions
	Result        HealingResult `json:"result"`      // set when the operation finishes
	StartedAt     time.Time     `json:"started_at"`
	CompletedAt   time.Time     `json:"completed_at,omitempty"`
	Error         string        `json:"error,omitempty"`
	AttemptNumber int           `json:"attempt_number"` // NOTE(review): never assigned in this file — confirm
}
// ActionLog records the execution of a single action for auditing.
type ActionLog struct {
	Action    ActionType    `json:"action"`     // which action ran
	StartedAt time.Time     `json:"started_at"` // wall-clock start
	Duration  time.Duration `json:"duration"`   // how long the executor took
	Success   bool          `json:"success"`    // true when the executor returned nil
	Error     string        `json:"error,omitempty"`
}
// ActionExecutorFunc is the callback that actually runs an action.
// Implementations handle the real system operations (restart, rollback, etc.).
// The context carries the per-action timeout when Action.Timeout > 0.
type ActionExecutorFunc func(ctx context.Context, action Action, component string) error
// HealingEngine is the L2 Self-Healing orchestrator. It consumes health
// alerts, matches them to registered strategies, and runs the healing FSM.
type HealingEngine struct {
	mu         sync.RWMutex
	strategies []HealingStrategy    // registered strategies, matched in order
	cooldowns  map[string]time.Time // strategyID → earliest next run
	operations []*HealingOperation  // full history of healing attempts
	opCounter  int64                // monotonic counter for operation IDs
	executor   ActionExecutorFunc   // performs the actual system operations
	alertBus   <-chan HealthAlert   // input alerts; closing it stops Start
	escalateFn func(HealthAlert)    // called on unrecoverable failure
	logger     *slog.Logger
}
// NewHealingEngine creates a new self-healing engine that consumes alerts
// from alertBus, runs actions through executor, and invokes escalateFn (may
// be nil) when a healing attempt ultimately fails.
func NewHealingEngine(
	alertBus <-chan HealthAlert,
	executor ActionExecutorFunc,
	escalateFn func(HealthAlert),
) *HealingEngine {
	engine := &HealingEngine{
		cooldowns:  map[string]time.Time{},
		operations: []*HealingOperation{},
		executor:   executor,
		alertBus:   alertBus,
		escalateFn: escalateFn,
		logger:     slog.Default().With("component", "sarl-healing-engine"),
	}
	return engine
}
// RegisterStrategy adds a healing strategy to the engine. Safe for
// concurrent use.
func (he *HealingEngine) RegisterStrategy(s HealingStrategy) {
	he.mu.Lock()
	he.strategies = append(he.strategies, s)
	he.mu.Unlock()
	he.logger.Info("strategy registered", "id", s.ID, "name", s.Name)
}
// Start begins listening for alerts and initiating healing. Blocks until ctx
// is cancelled or the alert channel is closed. Only WARNING and CRITICAL
// alerts enter the healing pipeline; lower severities are ignored.
func (he *HealingEngine) Start(ctx context.Context) {
	// Read the strategy count through the locked accessor: reading
	// len(he.strategies) directly here would race with a concurrent
	// RegisterStrategy call.
	he.logger.Info("healing engine started", "strategies", he.StrategyCount())
	for {
		select {
		case <-ctx.Done():
			he.logger.Info("healing engine stopped")
			return
		case alert, ok := <-he.alertBus:
			if !ok {
				// Producer closed the bus: no more alerts will arrive.
				return
			}
			if alert.Severity == SeverityCritical || alert.Severity == SeverityWarning {
				he.initiateHealing(ctx, alert)
			}
		}
	}
}
// initiateHealing runs the healing pipeline for an alert:
// diagnose → execute actions → verify → complete/fail (+ rollback/escalate).
//
// Direct writes to the operation's fields are guarded by he.mu because
// GetOperation and RecentOperations may be copying the operation from other
// goroutines concurrently. The lock is never held across calls into other
// engine methods, which take it themselves.
func (he *HealingEngine) initiateHealing(ctx context.Context, alert HealthAlert) {
	strategy := he.findStrategy(alert)
	if strategy == nil {
		he.logger.Info("no matching strategy for alert",
			"component", alert.Component,
			"metric", alert.Metric,
		)
		return
	}
	if he.isInCooldown(strategy.ID) {
		he.logger.Info("strategy in cooldown",
			"strategy", strategy.ID,
			"component", alert.Component,
		)
		return
	}
	op := he.createOperation(strategy, alert.Component)
	he.logger.Info("healing initiated",
		"op_id", op.ID,
		"strategy", strategy.ID,
		"component", alert.Component,
	)
	// Phase 1: Diagnose.
	he.transitionOp(op, HealingDiagnosing)
	diagnosis := he.diagnose(alert)
	he.mu.Lock()
	op.Diagnosis = &diagnosis
	he.mu.Unlock()
	// Phase 2: Execute healing actions.
	he.transitionOp(op, HealingActive)
	execErr := he.executeActions(ctx, strategy, op)
	// Phase 3: Verify recovery (only if the actions themselves succeeded).
	if execErr == nil {
		he.transitionOp(op, HealingVerifying)
		if verifyErr := he.verifyRecovery(ctx, strategy, op.Component); verifyErr != nil {
			execErr = verifyErr
		}
	}
	// Phase 4: Complete or fail.
	if execErr == nil {
		he.transitionOp(op, HealingCompleted)
		he.mu.Lock()
		op.Result = ResultSuccess
		he.mu.Unlock()
		he.logger.Info("healing completed successfully",
			"op_id", op.ID,
			"component", op.Component,
			"duration", time.Since(op.StartedAt),
		)
	} else {
		he.transitionOp(op, HealingFailed)
		he.mu.Lock()
		op.Result = ResultFailed
		op.Error = execErr.Error()
		he.mu.Unlock()
		he.logger.Error("healing failed",
			"op_id", op.ID,
			"component", op.Component,
			"error", execErr,
		)
		// Execute rollback (best effort), then escalate.
		he.executeRollback(ctx, strategy, op)
		if he.escalateFn != nil {
			he.escalateFn(alert)
		}
	}
	he.mu.Lock()
	op.CompletedAt = time.Now()
	he.mu.Unlock()
	he.setCooldown(strategy.ID, strategy.Cooldown)
}
// findStrategy returns a pointer to the first registered strategy whose
// trigger matches the alert, or nil when none match. Strategies are checked
// in registration order.
func (he *HealingEngine) findStrategy(alert HealthAlert) *HealingStrategy {
	he.mu.RLock()
	defer he.mu.RUnlock()
	for idx := range he.strategies {
		candidate := &he.strategies[idx]
		if he.matchesTrigger(candidate.Trigger, alert) {
			return candidate
		}
	}
	return nil
}
// matchesTrigger reports whether an alert satisfies a strategy's trigger:
// either the alert's metric name is listed in the trigger, or one of the
// trigger's component statuses corresponds to the alert's severity (and, for
// StatusOffline, a "restart" suggested action).
func (he *HealingEngine) matchesTrigger(trigger TriggerCondition, alert HealthAlert) bool {
	// Match by metric name.
	for _, metric := range trigger.Metrics {
		if metric == alert.Metric {
			return true
		}
	}
	// Match by component status.
	for _, status := range trigger.Statuses {
		if status == StatusCritical && alert.Severity == SeverityCritical {
			return true
		}
		if status == StatusOffline && alert.Severity == SeverityCritical && alert.SuggestedAction == "restart" {
			return true
		}
	}
	return false
}
// isInCooldown reports whether a strategy's cooldown window has not yet
// elapsed.
func (he *HealingEngine) isInCooldown(strategyID string) bool {
	he.mu.RLock()
	defer he.mu.RUnlock()
	next, found := he.cooldowns[strategyID]
	if !found {
		return false
	}
	return time.Now().Before(next)
}
// setCooldown blocks a strategy from running again until now+duration.
func (he *HealingEngine) setCooldown(strategyID string, duration time.Duration) {
	next := time.Now().Add(duration)
	he.mu.Lock()
	he.cooldowns[strategyID] = next
	he.mu.Unlock()
}
// createOperation allocates the next "heal-N" operation for a strategy and
// appends it to the engine's history under the lock.
func (he *HealingEngine) createOperation(strategy *HealingStrategy, component string) *HealingOperation {
	he.mu.Lock()
	defer he.mu.Unlock()
	he.opCounter++
	op := &HealingOperation{
		ID:         fmt.Sprintf("heal-%d", he.opCounter),
		StrategyID: strategy.ID,
		Component:  component,
		State:      HealingIdle,
		StartedAt:  time.Now(),
		ActionsRun: []ActionLog{},
	}
	he.operations = append(he.operations, op)
	return op
}
// transitionOp moves an operation to a new state.
// The write to op.State is guarded by he.mu: GetOperation and
// RecentOperations copy operations from other goroutines, so an unguarded
// write here would be a data race. Callers must not hold he.mu.
func (he *HealingEngine) transitionOp(op *HealingOperation, newState HealingState) {
	he.mu.Lock()
	prev := op.State
	op.State = newState
	he.mu.Unlock()
	he.logger.Debug("healing state transition",
		"op_id", op.ID,
		"from", prev,
		"to", newState,
	)
}
// diagnose performs heuristic root cause analysis for an alert, mapping
// well-known metrics to a cause, a confidence score, and a suggested fix.
// Unrecognized metrics fall back to a generic "threshold_breach_<metric>"
// diagnosis at 0.5 confidence.
func (he *HealingEngine) diagnose(alert HealthAlert) Diagnosis {
	d := Diagnosis{
		Component:    alert.Component,
		Metric:       alert.Metric,
		RootCause:    "unknown",
		Confidence:   0.5,
		SuggestedFix: "restart component",
	}
	switch {
	case alert.Metric == "memory" && alert.Current > 90:
		d.RootCause, d.Confidence, d.SuggestedFix = "memory_exhaustion", 0.9, "restart with increased limits"
	case alert.Metric == "cpu" && alert.Current > 90:
		d.RootCause, d.Confidence, d.SuggestedFix = "cpu_saturation", 0.8, "check for runaway goroutines"
	case alert.Metric == "error_rate":
		d.RootCause, d.Confidence, d.SuggestedFix = "elevated_error_rate", 0.7, "check dependencies and config"
	case alert.Metric == "latency_p99":
		d.RootCause, d.Confidence, d.SuggestedFix = "latency_degradation", 0.6, "check database and network"
	case alert.Metric == "quorum":
		d.RootCause, d.Confidence, d.SuggestedFix = "quorum_loss", 0.95, "activate safe mode"
	default:
		d.RootCause, d.Confidence, d.SuggestedFix = fmt.Sprintf("threshold_breach_%s", alert.Metric), 0.5, "investigate manually"
	}
	return d
}
// executeActions runs each of the strategy's actions in sequence, logging
// every attempt to op.ActionsRun. On action failure the behavior is driven
// by the action's OnError policy: "continue" logs and moves on, while
// "rollback" and anything else ("abort") stop and return the error.
func (he *HealingEngine) executeActions(ctx context.Context, strategy *HealingStrategy, op *HealingOperation) error {
	for _, action := range strategy.Actions {
		// Apply the per-action timeout, if any. cancel is called manually
		// (not deferred) so each action's context is released promptly
		// inside the loop.
		actionCtx := ctx
		var cancel context.CancelFunc
		if action.Timeout > 0 {
			actionCtx, cancel = context.WithTimeout(ctx, action.Timeout)
		}
		start := time.Now()
		err := he.executor(actionCtx, action, op.Component)
		duration := time.Since(start)
		if cancel != nil {
			cancel()
		}
		logEntry := ActionLog{
			Action:    action.Type,
			StartedAt: start,
			Duration:  duration,
			Success:   err == nil,
		}
		if err != nil {
			logEntry.Error = err.Error()
		}
		// Guard the append: GetOperation/RecentOperations copy ActionsRun
		// from other goroutines, so an unguarded append would race.
		he.mu.Lock()
		op.ActionsRun = append(op.ActionsRun, logEntry)
		he.mu.Unlock()
		if err != nil {
			switch action.OnError {
			case "continue":
				he.logger.Warn("action failed, continuing",
					"action", action.Type,
					"error", err,
				)
			case "rollback":
				return fmt.Errorf("action %s failed (rollback): %w", action.Type, err)
			default: // "abort"
				return fmt.Errorf("action %s failed: %w", action.Type, err)
			}
		}
	}
	return nil
}
// verifyRecovery checks if the component is healthy after healing by running
// a synthetic verify_health action with a 30s timeout through the executor.
// The strategy parameter is currently unused.
func (he *HealingEngine) verifyRecovery(ctx context.Context, strategy *HealingStrategy, component string) error {
	check := Action{
		Type:    ActionVerifyHealth,
		Timeout: 30 * time.Second,
	}
	return he.executor(ctx, check, component)
}
// executeRollback runs the strategy's rollback plan after a failed healing.
// Rollback is best-effort: individual action failures are logged but do not
// stop the remaining rollback actions.
func (he *HealingEngine) executeRollback(ctx context.Context, strategy *HealingStrategy, op *HealingOperation) {
	actions := strategy.Rollback.Actions
	if len(actions) == 0 {
		he.logger.Info("no rollback actions defined",
			"strategy", strategy.ID,
		)
		return
	}
	he.logger.Warn("executing rollback",
		"strategy", strategy.ID,
		"component", op.Component,
	)
	for _, action := range actions {
		err := he.executor(ctx, action, op.Component)
		if err == nil {
			continue
		}
		he.logger.Error("rollback action failed",
			"action", action.Type,
			"error", err,
		)
	}
}
// GetOperation returns a healing operation by ID, or false when no operation
// with that ID exists. The result is a deep copy (ActionsRun and Diagnosis
// included) so callers never share mutable state with the healing goroutine.
func (he *HealingEngine) GetOperation(id string) (*HealingOperation, bool) {
	he.mu.RLock()
	defer he.mu.RUnlock()
	for _, op := range he.operations {
		if op.ID != id {
			continue
		}
		dup := *op
		dup.ActionsRun = append([]ActionLog(nil), op.ActionsRun...)
		if op.Diagnosis != nil {
			d := *op.Diagnosis
			dup.Diagnosis = &d
		}
		return &dup, true
	}
	return nil, false
}
// RecentOperations returns the last N operations (oldest first within the
// window). Returns deep copies to prevent data races with the healing
// goroutine, and nil when there is nothing to return.
func (he *HealingEngine) RecentOperations(n int) []HealingOperation {
	// Guard n <= 0 up front: a negative n would previously panic in
	// make([]HealingOperation, 0, n).
	if n <= 0 {
		return nil
	}
	he.mu.RLock()
	defer he.mu.RUnlock()
	total := len(he.operations)
	if total == 0 {
		return nil
	}
	start := total - n
	if start < 0 {
		start = 0
	}
	result := make([]HealingOperation, 0, total-start)
	for i := start; i < total; i++ {
		cp := *he.operations[i]
		cp.ActionsRun = make([]ActionLog, len(he.operations[i].ActionsRun))
		copy(cp.ActionsRun, he.operations[i].ActionsRun)
		if he.operations[i].Diagnosis != nil {
			diag := *he.operations[i].Diagnosis
			cp.Diagnosis = &diag
		}
		result = append(result, cp)
	}
	return result
}
// StrategyCount reports how many strategies are currently registered.
func (he *HealingEngine) StrategyCount() int {
	he.mu.RLock()
	n := len(he.strategies)
	he.mu.RUnlock()
	return n
}

View file

@ -0,0 +1,588 @@
package resilience
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
)
// --- Mock executor for tests ---

// mockExecutorLog records every action the healing engine asks it to run and
// can be configured to fail specific action types.
// NOTE(review): `actions` is appended from the engine goroutine (via execute)
// and read by tests without synchronization — only `count` is atomic. This
// may be flagged under -race; confirm whether a mutex is needed.
type mockExecutorLog struct {
	actions []ActionType        // ordered log of executed action types
	fail    map[ActionType]bool // action types that should return an error
	count   atomic.Int64        // total number of execute calls
}
// newMockExecutor returns a mock executor with an empty failure map.
func newMockExecutor() *mockExecutorLog {
	m := &mockExecutorLog{}
	m.fail = map[ActionType]bool{}
	return m
}
// execute implements ActionExecutorFunc for tests: it records the action and
// fails iff the action type is marked in m.fail.
// NOTE(review): the append to m.actions is unsynchronized while m.count is
// atomic — see the note on mockExecutorLog.
func (m *mockExecutorLog) execute(_ context.Context, action Action, _ string) error {
	m.count.Add(1)
	m.actions = append(m.actions, action.Type)
	if m.fail[action.Type] {
		return fmt.Errorf("action %s failed", action.Type)
	}
	return nil
}
// --- Healing Engine Tests ---
// HE-01: Component restart (success).
// One critical quorum alert against the restart strategy must yield a single
// SUCCESS operation with no escalation.
// NOTE(review): `escalated` would be written from the engine goroutine and
// read here after cancel() without a happens-before edge — confirm under -race.
func TestHealingEngine_HE01_RestartSuccess(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	escalated := false
	he := NewHealingEngine(alertCh, mock.execute, func(_ HealthAlert) {
		escalated = true
	})
	he.RegisterStrategy(RestartComponentStrategy())
	alertCh <- HealthAlert{
		Component:       "soc-ingest",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		SuggestedAction: "restart",
		Timestamp:       time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	// Run one healing cycle.
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // give the engine time to process the alert
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected at least 1 operation")
	}
	if ops[0].Result != ResultSuccess {
		t.Errorf("expected SUCCESS, got %s (error: %s)", ops[0].Result, ops[0].Error)
	}
	if escalated {
		t.Error("should not have escalated on success")
	}
}
// HE-02: Component restart (failure ×3 → escalate).
// With start_component forced to fail, the operation must end FAILED and the
// escalate callback must fire.
func TestHealingEngine_HE02_RestartFailureEscalate(t *testing.T) {
	mock := newMockExecutor()
	mock.fail[ActionStartComponent] = true // Start always fails.
	alertCh := make(chan HealthAlert, 10)
	escalated := false
	he := NewHealingEngine(alertCh, mock.execute, func(_ HealthAlert) {
		escalated = true
	})
	he.RegisterStrategy(RestartComponentStrategy())
	alertCh <- HealthAlert{
		Component:       "soc-correlate",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		SuggestedAction: "restart",
		Timestamp:       time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // give the engine time to fail and escalate
	cancel()
	if !escalated {
		t.Error("expected escalation on failure")
	}
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected operation")
	}
	if ops[0].Result != ResultFailed {
		t.Errorf("expected FAILED, got %s", ops[0].Result)
	}
}
// HE-03: Config rollback strategy matching.
// A config_tampering warning must be routed to the ROLLBACK_CONFIG strategy.
func TestHealingEngine_HE03_ConfigRollback(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RollbackConfigStrategy())
	alertCh <- HealthAlert{
		Component: "soc-ingest",
		Severity:  SeverityWarning,
		Metric:    "config_tampering",
		Timestamp: time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow one healing cycle
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected operation for config rollback")
	}
	if ops[0].StrategyID != "ROLLBACK_CONFIG" {
		t.Errorf("expected ROLLBACK_CONFIG, got %s", ops[0].StrategyID)
	}
}
// HE-04: Database recovery.
// A database_corruption alert must be routed to the RECOVER_DATABASE strategy.
func TestHealingEngine_HE04_DatabaseRecovery(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RecoverDatabaseStrategy())
	alertCh <- HealthAlert{
		Component: "soc-correlate",
		Severity:  SeverityCritical,
		Metric:    "database_corruption",
		Timestamp: time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow one healing cycle
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected DB recovery op")
	}
	if ops[0].StrategyID != "RECOVER_DATABASE" {
		t.Errorf("expected RECOVER_DATABASE, got %s", ops[0].StrategyID)
	}
}
// HE-05: Rule poisoning defense.
// A rule_execution_failure_rate alert must be routed to the RECOVER_RULES
// strategy.
func TestHealingEngine_HE05_RulePoisoning(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RecoverRulesStrategy())
	alertCh <- HealthAlert{
		Component: "soc-correlate",
		Severity:  SeverityWarning,
		Metric:    "rule_execution_failure_rate",
		Timestamp: time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow one healing cycle
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected rule recovery op")
	}
	if ops[0].StrategyID != "RECOVER_RULES" {
		t.Errorf("expected RECOVER_RULES, got %s", ops[0].StrategyID)
	}
}
// HE-06: Network isolation recovery.
// A network_partition alert must be routed to the RECOVER_NETWORK strategy.
func TestHealingEngine_HE06_NetworkRecovery(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RecoverNetworkStrategy())
	alertCh <- HealthAlert{
		Component: "soc-respond",
		Severity:  SeverityWarning,
		Metric:    "network_partition",
		Timestamp: time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow one healing cycle
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected network recovery op")
	}
	if ops[0].StrategyID != "RECOVER_NETWORK" {
		t.Errorf("expected RECOVER_NETWORK, got %s", ops[0].StrategyID)
	}
}
// HE-07: Cooldown enforcement.
// With the strategy's cooldown set one hour into the future, an otherwise
// matching alert must produce no healing operation.
func TestHealingEngine_HE07_Cooldown(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RestartComponentStrategy())
	// Set cooldown manually.
	he.setCooldown("RESTART_COMPONENT", 1*time.Hour)
	if !he.isInCooldown("RESTART_COMPONENT") {
		t.Error("expected cooldown active")
	}
	alertCh <- HealthAlert{
		Component:       "soc-ingest",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		SuggestedAction: "restart",
		Timestamp:       time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow the engine to (not) process the alert
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) != 0 {
		t.Error("expected 0 operations during cooldown")
	}
}
// HE-08: Rollback on failure.
// When a healing action fails, the strategy's rollback plan must run;
// for RollbackConfigStrategy that is expected to include enter_safe_mode.
// NOTE(review): mock.actions is read here without synchronization against
// the engine goroutine — see the note on mockExecutorLog.
func TestHealingEngine_HE08_Rollback(t *testing.T) {
	mock := newMockExecutor()
	mock.fail[ActionStartComponent] = true
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, func(_ HealthAlert) {})
	strategy := RollbackConfigStrategy()
	he.RegisterStrategy(strategy)
	alertCh <- HealthAlert{
		Component: "soc-ingest",
		Severity:  SeverityWarning,
		Metric:    "config_tampering",
		Timestamp: time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow healing + rollback to run
	cancel()
	// Rollback should have executed enter_safe_mode.
	foundSafeMode := false
	for _, a := range mock.actions {
		if a == ActionEnterSafeMode {
			foundSafeMode = true
		}
	}
	if !foundSafeMode {
		t.Errorf("expected safe mode in rollback, actions: %v", mock.actions)
	}
}
// HE-09: State machine transitions.
// A successful healing run must leave the operation in the terminal
// COMPLETED state.
func TestHealingEngine_HE09_StateTransitions(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RestartComponentStrategy())
	alertCh <- HealthAlert{
		Component:       "comp",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		SuggestedAction: "restart",
		Timestamp:       time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow the full FSM to run
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected operation")
	}
	// Final state should be COMPLETED.
	if ops[0].State != HealingCompleted {
		t.Errorf("expected COMPLETED, got %s", ops[0].State)
	}
}
// HE-10: Audit logging — all actions recorded.
// Every executed action must appear in the operation's ActionsRun with a
// non-zero start time.
func TestHealingEngine_HE10_AuditLogging(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RestartComponentStrategy())
	alertCh <- HealthAlert{
		Component:       "comp",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		SuggestedAction: "restart",
		Timestamp:       time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond) // allow the actions to run
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected operation")
	}
	if len(ops[0].ActionsRun) == 0 {
		t.Error("expected action logs")
	}
	for _, al := range ops[0].ActionsRun {
		if al.StartedAt.IsZero() {
			t.Error("action log missing start time")
		}
	}
}
// HE-11: Parallel healing — no race conditions.
// Floods the engine with concurrent alerts; correctness here is "no data
// race / no panic" (run with -race) plus at least one recorded operation.
func TestHealingEngine_HE11_Parallel(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 100)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	for _, s := range DefaultStrategies() {
		he.RegisterStrategy(s)
	}
	// Send many alerts concurrently.
	for i := 0; i < 10; i++ {
		alertCh <- HealthAlert{
			Component:       fmt.Sprintf("comp-%d", i),
			Severity:        SeverityCritical,
			Metric:          "quorum",
			SuggestedAction: "restart",
			Timestamp:       time.Now(),
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(1 * time.Second)
	cancel()
	// All 10 alerts processed (first gets an op, rest hit cooldown).
	ops := he.RecentOperations(100)
	if len(ops) == 0 {
		t.Fatal("expected at least 1 operation")
	}
}
// HE-12: No matching strategy → no operation.
// An alert for which no strategy is registered must be ignored entirely.
func TestHealingEngine_HE12_NoStrategy(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	// No strategies registered.
	alertCh <- HealthAlert{
		Component: "comp",
		Severity:  SeverityCritical,
		Metric:    "unknown_metric",
		Timestamp: time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond)
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) != 0 {
		t.Errorf("expected 0 operations, got %d", len(ops))
	}
}
// Test diagnosis (various root causes).
// Table-driven check that diagnose maps each alert metric to the expected
// root-cause label and produces a confidence in (0, 1].
func TestHealingEngine_Diagnosis(t *testing.T) {
	mock := newMockExecutor()
	he := NewHealingEngine(nil, mock.execute, nil)
	tests := []struct {
		metric    string
		current   float64
		wantCause string
	}{
		{"memory", 95, "memory_exhaustion"},
		{"cpu", 95, "cpu_saturation"},
		{"error_rate", 10, "elevated_error_rate"},
		{"latency_p99", 200, "latency_degradation"},
		{"quorum", 0.3, "quorum_loss"},
		// Unknown metrics fall back to a generic threshold-breach cause.
		{"custom", 100, "threshold_breach_custom"},
	}
	for _, tt := range tests {
		alert := HealthAlert{
			Component: "test",
			Metric:    tt.metric,
			Current:   tt.current,
		}
		d := he.diagnose(alert)
		if d.RootCause != tt.wantCause {
			t.Errorf("metric=%s: expected %s, got %s", tt.metric, tt.wantCause, d.RootCause)
		}
		if d.Confidence <= 0 || d.Confidence > 1 {
			t.Errorf("metric=%s: invalid confidence %f", tt.metric, d.Confidence)
		}
	}
}
// Test DefaultStrategies: exactly 5 strategies, with unique IDs and sane
// attempt/cooldown/action settings on each.
func TestDefaultStrategies(t *testing.T) {
	strategies := DefaultStrategies()
	if len(strategies) != 5 {
		t.Errorf("expected 5 strategies, got %d", len(strategies))
	}
	seen := make(map[string]bool, len(strategies))
	for _, s := range strategies {
		if seen[s.ID] {
			t.Errorf("duplicate strategy ID: %s", s.ID)
		}
		seen[s.ID] = true
		if s.MaxAttempts <= 0 {
			t.Errorf("strategy %s: invalid max_attempts %d", s.ID, s.MaxAttempts)
		}
		if s.Cooldown <= 0 {
			t.Errorf("strategy %s: invalid cooldown %v", s.ID, s.Cooldown)
		}
		if len(s.Actions) == 0 {
			t.Errorf("strategy %s: no actions defined", s.ID)
		}
	}
}
// Test StrategyCount: zero on a fresh engine, then tracks registrations.
func TestHealingEngine_StrategyCount(t *testing.T) {
	he := NewHealingEngine(nil, nil, nil)
	if he.StrategyCount() != 0 {
		t.Error("expected 0")
	}
	strategies := DefaultStrategies()
	for _, s := range strategies {
		he.RegisterStrategy(s)
	}
	if he.StrategyCount() != 5 {
		t.Errorf("expected 5, got %d", he.StrategyCount())
	}
}
// Test GetOperation.
// Looks up a completed operation by its generated ID ("heal-1") and
// verifies the not-found path for an unknown ID.
func TestHealingEngine_GetOperation(t *testing.T) {
	mock := newMockExecutor()
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RestartComponentStrategy())
	alertCh <- HealthAlert{
		Component:       "comp",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		SuggestedAction: "restart",
		Timestamp:       time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond)
	cancel()
	op, ok := he.GetOperation("heal-1")
	if !ok {
		t.Fatal("expected operation heal-1")
	}
	if op.Component != "comp" {
		t.Errorf("expected comp, got %s", op.Component)
	}
	_, ok = he.GetOperation("nonexistent")
	if ok {
		t.Error("expected not found for nonexistent")
	}
}
// Test action OnError=continue.
// A failure in an action marked OnError=continue must not abort the
// strategy: the operation still finishes with ResultSuccess.
func TestHealingEngine_ActionContinueOnError(t *testing.T) {
	mock := newMockExecutor()
	mock.fail[ActionGracefulStop] = true // First action fails but marked continue.
	alertCh := make(chan HealthAlert, 10)
	he := NewHealingEngine(alertCh, mock.execute, nil)
	he.RegisterStrategy(RestartComponentStrategy())
	alertCh <- HealthAlert{
		Component:       "comp",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		SuggestedAction: "restart",
		Timestamp:       time.Now(),
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	go he.Start(ctx)
	time.Sleep(200 * time.Millisecond)
	cancel()
	ops := he.RecentOperations(10)
	if len(ops) == 0 {
		t.Fatal("expected operation")
	}
	// Should still succeed because graceful_stop has OnError=continue.
	if ops[0].Result != ResultSuccess {
		t.Errorf("expected SUCCESS (continue on error), got %s", ops[0].Result)
	}
}

View file

@ -0,0 +1,215 @@
package resilience
import "time"
// Built-in healing strategies per the technical specification (ТЗ) §4.1.1.
// These are registered at startup via HealingEngine.RegisterStrategy().
// DefaultStrategies returns the 5 built-in healing strategies.
func DefaultStrategies() []HealingStrategy {
	strategies := make([]HealingStrategy, 0, 5)
	strategies = append(strategies,
		RestartComponentStrategy(),
		RollbackConfigStrategy(),
		RecoverDatabaseStrategy(),
		RecoverRulesStrategy(),
		RecoverNetworkStrategy(),
	)
	return strategies
}
// RestartComponentStrategy handles component crashes and offline states.
// Trigger: component_offline OR component_critical, 2 consecutive failures within 5m.
// Actions: graceful_stop → clear_temp → start → verify → notify.
// Rollback: escalate to next strategy (architect is notified).
func RestartComponentStrategy() HealingStrategy {
	return HealingStrategy{
		ID:   "RESTART_COMPONENT",
		Name: "Component Restart",
		Trigger: TriggerCondition{
			Statuses:            []ComponentStatus{StatusOffline, StatusCritical},
			ConsecutiveFailures: 2,
			WithinWindow:        5 * time.Minute,
		},
		Actions: []Action{
			// Stop/cleanup are best-effort (continue); start/verify are
			// mandatory (abort).
			{Type: ActionGracefulStop, Timeout: 10 * time.Second, OnError: "continue"},
			{Type: ActionClearTempFiles, Timeout: 5 * time.Second, OnError: "continue"},
			{Type: ActionStartComponent, Timeout: 30 * time.Second, OnError: "abort"},
			{Type: ActionVerifyHealth, Timeout: 60 * time.Second, OnError: "abort"},
			{Type: ActionNotifySOC, Timeout: 5 * time.Second, OnError: "continue",
				Params: map[string]interface{}{
					"severity": "INFO",
					"message":  "Component restarted successfully",
				},
			},
		},
		Rollback: RollbackPlan{
			OnFailure: "escalate",
			Actions: []Action{
				{Type: ActionNotifyArchitect, Timeout: 5 * time.Second,
					Params: map[string]interface{}{
						"severity": "CRITICAL",
						"message":  "Component restart failed after max attempts",
					},
				},
			},
		},
		MaxAttempts: 3,
		Cooldown:    5 * time.Minute,
	}
}
// RollbackConfigStrategy handles config tampering or validation failures.
// Trigger: config_tampering_detected OR config_validation_failed.
// Actions: freeze → rollback → start → verify_config → notify.
// Single attempt only: a failed config rollback drops into safe mode.
func RollbackConfigStrategy() HealingStrategy {
	return HealingStrategy{
		ID:   "ROLLBACK_CONFIG",
		Name: "Configuration Rollback",
		Trigger: TriggerCondition{
			Metrics: []string{"config_tampering", "config_validation"},
		},
		Actions: []Action{
			{Type: ActionFreezeConfig, Timeout: 5 * time.Second, OnError: "abort"},
			{Type: ActionRollbackConfig, Timeout: 15 * time.Second, OnError: "abort"},
			// A failed restart after rollback triggers the rollback plan.
			{Type: ActionStartComponent, Timeout: 30 * time.Second, OnError: "rollback"},
			{Type: ActionVerifyConfig, Timeout: 10 * time.Second, OnError: "abort"},
			{Type: ActionNotifyArchitect, Timeout: 5 * time.Second, OnError: "continue",
				Params: map[string]interface{}{
					"severity": "WARNING",
					"message":  "Config rolled back due to tampering",
				},
			},
		},
		Rollback: RollbackPlan{
			OnFailure: "enter_safe_mode",
			Actions: []Action{
				{Type: ActionEnterSafeMode, Timeout: 10 * time.Second},
			},
		},
		MaxAttempts: 1,
		Cooldown:    1 * time.Hour,
	}
}
// RecoverDatabaseStrategy handles SQLite corruption.
// Trigger: database_corruption OR sqlite_integrity_failed.
// Actions: readonly → backup → restore → verify → resume → notify.
// On total failure the system enters safe mode and the architect is paged.
func RecoverDatabaseStrategy() HealingStrategy {
	return HealingStrategy{
		ID:   "RECOVER_DATABASE",
		Name: "Database Recovery",
		Trigger: TriggerCondition{
			Metrics: []string{"database_corruption", "sqlite_integrity"},
		},
		Actions: []Action{
			{Type: ActionSwitchReadOnly, Timeout: 5 * time.Second, OnError: "abort"},
			// Best-effort backup of the corrupted DB before restoring.
			{Type: ActionBackupDB, Timeout: 30 * time.Second, OnError: "continue"},
			{Type: ActionRestoreSnapshot, Timeout: 60 * time.Second, OnError: "abort",
				Params: map[string]interface{}{
					"snapshot_age_max": "1h",
				},
			},
			{Type: ActionVerifyIntegrity, Timeout: 30 * time.Second, OnError: "abort"},
			{Type: ActionResumeWrites, Timeout: 5 * time.Second, OnError: "abort"},
			{Type: ActionNotifySOC, Timeout: 5 * time.Second, OnError: "continue",
				Params: map[string]interface{}{
					"severity": "WARNING",
					"message":  "Database recovered from snapshot",
				},
			},
		},
		Rollback: RollbackPlan{
			OnFailure: "enter_lockdown",
			Actions: []Action{
				{Type: ActionEnterSafeMode, Timeout: 10 * time.Second},
				{Type: ActionNotifyArchitect, Timeout: 5 * time.Second,
					Params: map[string]interface{}{
						"severity": "CRITICAL",
						"message":  "Database recovery failed",
					},
				},
			},
		},
		MaxAttempts: 2,
		Cooldown:    2 * time.Hour,
	}
}
// RecoverRulesStrategy handles correlation rule poisoning.
// Trigger: rule execution failure rate > 50%.
// Actions: disable_suspicious → revert_baseline → reload → verify → notify.
func RecoverRulesStrategy() HealingStrategy {
	return HealingStrategy{
		ID:   "RECOVER_RULES",
		Name: "Rule Poisoning Defense",
		Trigger: TriggerCondition{
			Metrics: []string{"rule_execution_failure_rate", "correlation_rule_anomaly"},
		},
		Actions: []Action{
			{Type: ActionDisableRules, Timeout: 10 * time.Second, OnError: "abort",
				Params: map[string]interface{}{
					"criteria": "failure_rate > 80%",
				},
			},
			{Type: ActionRevertRules, Timeout: 15 * time.Second, OnError: "abort"},
			{Type: ActionReloadEngine, Timeout: 30 * time.Second, OnError: "abort"},
			// Verification is advisory here — the reverted baseline is kept
			// even if the health check cannot confirm it in time.
			{Type: ActionVerifyHealth, Timeout: 30 * time.Second, OnError: "continue"},
			{Type: ActionNotifyArchitect, Timeout: 5 * time.Second, OnError: "continue",
				Params: map[string]interface{}{
					"severity": "WARNING",
					"message":  "Rules recovered from baseline",
				},
			},
		},
		Rollback: RollbackPlan{
			// No explicit rollback actions: the engine's "disable_correlation"
			// failure mode takes over.
			OnFailure: "disable_correlation",
		},
		MaxAttempts: 2,
		Cooldown:    4 * time.Hour,
	}
}
// RecoverNetworkStrategy handles network partition or mTLS cert expiry.
// Trigger: network_partition_detected OR mTLS_cert_expired.
// Actions: isolate → regen_certs → verify → restore → notify.
// If verification fails, the rollback plan keeps the node isolated.
func RecoverNetworkStrategy() HealingStrategy {
	return HealingStrategy{
		ID:   "RECOVER_NETWORK",
		Name: "Network Isolation Recovery",
		Trigger: TriggerCondition{
			Metrics: []string{"network_partition", "mtls_cert_expiry"},
		},
		Actions: []Action{
			{Type: ActionIsolateNetwork, Timeout: 5 * time.Second, OnError: "abort",
				Params: map[string]interface{}{
					"scope": "external_only",
				},
			},
			{Type: ActionRegenCerts, Timeout: 30 * time.Second, OnError: "abort",
				Params: map[string]interface{}{
					"validity": "24h",
				},
			},
			{Type: ActionVerifyHealth, Timeout: 30 * time.Second, OnError: "rollback"},
			{Type: ActionRestoreNetwork, Timeout: 10 * time.Second, OnError: "abort"},
			{Type: ActionNotifySOC, Timeout: 5 * time.Second, OnError: "continue",
				Params: map[string]interface{}{
					"severity": "INFO",
					"message":  "Network connectivity restored",
				},
			},
		},
		Rollback: RollbackPlan{
			OnFailure: "maintain_isolation",
			Actions: []Action{
				{Type: ActionNotifyArchitect, Timeout: 5 * time.Second,
					Params: map[string]interface{}{
						"severity": "CRITICAL",
						"message":  "Network recovery failed, maintaining isolation",
					},
				},
			},
		},
		MaxAttempts: 3,
		Cooldown:    1 * time.Hour,
	}
}

View file

@ -0,0 +1,445 @@
package resilience
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// ComponentStatus defines the health state of a monitored component.
type ComponentStatus string

const (
	StatusHealthy  ComponentStatus = "HEALTHY"
	StatusDegraded ComponentStatus = "DEGRADED"
	StatusCritical ComponentStatus = "CRITICAL"
	StatusOffline  ComponentStatus = "OFFLINE"
)

// AlertSeverity defines the severity of a health alert.
type AlertSeverity string

const (
	SeverityInfo     AlertSeverity = "INFO"
	SeverityWarning  AlertSeverity = "WARNING"
	SeverityCritical AlertSeverity = "CRITICAL"
)

// OverallStatus aggregates component statuses into a system-wide status.
type OverallStatus string

const (
	OverallHealthy  OverallStatus = "HEALTHY"
	OverallDegraded OverallStatus = "DEGRADED"
	OverallCritical OverallStatus = "CRITICAL"
)

// Default intervals per the technical specification (ТЗ) §3.1.2.
const (
	MetricsCollectionInterval = 10 * time.Second
	HealthCheckInterval       = 30 * time.Second
	QuorumValidationInterval  = 60 * time.Second
	// AnomalyZScoreThreshold — Z > 3.0 = anomaly (99.7% confidence).
	AnomalyZScoreThreshold = 3.0
	// QuorumThreshold — 2/3 must be healthy. 0.66 is slightly below the
	// exact 2/3, so a 2-of-3 ratio satisfies the >= comparison despite
	// floating-point rounding.
	QuorumThreshold = 0.66
	// MaxConsecutiveFailures before marking CRITICAL.
	MaxConsecutiveFailures = 3
)
// ComponentConfig defines monitoring thresholds for a component.
type ComponentConfig struct {
	Name string `json:"name"`
	Type string `json:"type"` // go_binary, c_binary, c_kernel_module
	// Thresholds maps metric name → limit value.
	Thresholds map[string]float64 `json:"thresholds"`
	// Whether threshold is an upper bound (true) or lower bound (false).
	ThresholdIsMax map[string]bool `json:"threshold_is_max"`
}

// ComponentHealth tracks the health state of a single component.
type ComponentHealth struct {
	Name    string             `json:"name"`
	Status  ComponentStatus    `json:"status"`
	Metrics map[string]float64 `json:"metrics"`
	// LastCheck is the time of the last successful metrics collection.
	LastCheck time.Time `json:"last_check"`
	// Consecutive counts back-to-back failed checks/collections.
	Consecutive int             `json:"consecutive_failures"`
	Config      ComponentConfig `json:"-"`
}

// HealthAlert represents a detected health anomaly.
type HealthAlert struct {
	Component string        `json:"component"`
	Severity  AlertSeverity `json:"severity"`
	Metric    string        `json:"metric"`
	Current   float64       `json:"current"`
	Threshold float64       `json:"threshold"`
	// ZScore is set only for statistical anomalies, not threshold breaches.
	ZScore          float64   `json:"z_score,omitempty"`
	Timestamp       time.Time `json:"timestamp"`
	SuggestedAction string    `json:"suggested_action"`
}

// HealthResponse is the API response for GET /api/v1/resilience/health.
type HealthResponse struct {
	OverallStatus     OverallStatus     `json:"overall_status"`
	Components        []ComponentHealth `json:"components"`
	QuorumValid       bool              `json:"quorum_valid"`
	LastCheck         time.Time         `json:"last_check"`
	AnomaliesDetected []HealthAlert     `json:"anomalies_detected"`
}

// MetricsCollector is the interface for collecting metrics from components.
// Implementations can use /healthz endpoints, /metrics, or runtime stats.
type MetricsCollector interface {
	Collect(ctx context.Context, component string) (map[string]float64, error)
}
// HealthMonitor is the L1 Self-Monitoring orchestrator.
// It collects metrics, runs anomaly detection, validates quorum,
// and emits HealthAlerts to the alert bus.
type HealthMonitor struct {
	// mu guards components; metricsDB has its own internal synchronization
	// (it is also written while only mu's read lock is held — see
	// collectMetrics).
	mu         sync.RWMutex
	components map[string]*ComponentHealth
	metricsDB  *MetricsDB
	// alertBus is a buffered channel; emitAlert drops on overflow.
	alertBus  chan HealthAlert
	collector MetricsCollector
	logger    *slog.Logger
	// anomalyWindow is the baseline window for Z-score calculation.
	anomalyWindow time.Duration
}
// NewHealthMonitor creates a new health monitor. A non-positive
// alertBufSize falls back to a buffer of 100 alerts.
func NewHealthMonitor(collector MetricsCollector, alertBufSize int) *HealthMonitor {
	bufSize := alertBufSize
	if bufSize <= 0 {
		bufSize = 100
	}
	hm := &HealthMonitor{
		components:    map[string]*ComponentHealth{},
		metricsDB:     NewMetricsDB(DefaultMetricsWindow, DefaultMetricsMaxSize),
		alertBus:      make(chan HealthAlert, bufSize),
		collector:     collector,
		logger:        slog.Default().With("component", "sarl-health-monitor"),
		anomalyWindow: 24 * time.Hour,
	}
	return hm
}
// RegisterComponent adds a component to be monitored. It starts in the
// HEALTHY state with an empty metrics map.
func (hm *HealthMonitor) RegisterComponent(config ComponentConfig) {
	entry := &ComponentHealth{
		Name:    config.Name,
		Status:  StatusHealthy,
		Metrics: map[string]float64{},
		Config:  config,
	}
	hm.mu.Lock()
	defer hm.mu.Unlock()
	hm.components[config.Name] = entry
	hm.logger.Info("component registered", "name", config.Name, "type", config.Type)
}
// AlertBus returns the receive-only channel for consuming health alerts.
func (hm *HealthMonitor) AlertBus() <-chan HealthAlert {
	return hm.alertBus
}
// Start begins the three monitoring loops (metrics collection, health
// checks, quorum validation). Blocks until ctx is cancelled.
func (hm *HealthMonitor) Start(ctx context.Context) {
	hm.logger.Info("health monitor started")
	metrics := time.NewTicker(MetricsCollectionInterval)
	health := time.NewTicker(HealthCheckInterval)
	quorum := time.NewTicker(QuorumValidationInterval)
	defer func() {
		metrics.Stop()
		health.Stop()
		quorum.Stop()
	}()
	for {
		select {
		case <-metrics.C:
			hm.collectMetrics(ctx)
		case <-health.C:
			hm.checkHealth()
		case <-quorum.C:
			hm.validateQuorum()
		case <-ctx.Done():
			hm.logger.Info("health monitor stopped")
			return
		}
	}
}
// collectMetrics gathers metrics from all registered components.
// The component name list is snapshotted under RLock first so the
// (potentially slow) collector calls run without holding the lock.
func (hm *HealthMonitor) collectMetrics(ctx context.Context) {
	hm.mu.RLock()
	names := make([]string, 0, len(hm.components))
	for name := range hm.components {
		names = append(names, name)
	}
	hm.mu.RUnlock()
	for _, name := range names {
		metrics, err := hm.collector.Collect(ctx, name)
		if err != nil {
			hm.logger.Warn("metrics collection failed", "component", name, "error", err)
			hm.mu.Lock()
			if comp, ok := hm.components[name]; ok {
				// Only the counter is bumped here; status changes are
				// decided by checkHealth/evaluateComponent.
				// NOTE(review): evaluateComponent resets Consecutive to 0
				// on any clean check, so collection failures may be wiped
				// before reaching MaxConsecutiveFailures — confirm intent.
				comp.Consecutive++
			}
			hm.mu.Unlock()
			continue
		}
		hm.mu.Lock()
		comp, ok := hm.components[name]
		if ok {
			comp.Metrics = metrics
			comp.LastCheck = time.Now()
			// Store each metric in time-series DB.
			for metric, value := range metrics {
				hm.metricsDB.AddDataPoint(name, metric, value)
			}
		}
		hm.mu.Unlock()
	}
}
// checkHealth evaluates every registered component against its thresholds
// and anomaly baselines, emitting any resulting alerts.
func (hm *HealthMonitor) checkHealth() {
	hm.mu.Lock()
	defer hm.mu.Unlock()
	for _, comp := range hm.components {
		for _, alert := range hm.evaluateComponent(comp) {
			hm.emitAlert(alert)
		}
	}
}
// evaluateComponent checks a single component's metrics against thresholds
// and runs Z-score anomaly detection. Returns any generated alerts.
// Caller must hold hm.mu (write lock): this mutates comp.Status and
// comp.Consecutive.
func (hm *HealthMonitor) evaluateComponent(comp *ComponentHealth) []HealthAlert {
	var alerts []HealthAlert
	breached := false
	for metric, value := range comp.Metrics {
		threshold, hasThreshold := comp.Config.Thresholds[metric]
		if !hasThreshold {
			// NOTE(review): metrics without a configured threshold are
			// skipped entirely, so they also bypass the Z-score check
			// below — confirm that is intentional.
			continue
		}
		// ThresholdIsMax selects the comparison direction: true = upper
		// bound, false = lower bound.
		isMax := comp.Config.ThresholdIsMax[metric]
		var exceeded bool
		if isMax {
			exceeded = value > threshold
		} else {
			exceeded = value < threshold
		}
		if exceeded {
			breached = true
			action := "restart"
			if metric == "error_rate" || metric == "latency_p99" {
				action = "investigate"
			}
			alerts = append(alerts, HealthAlert{
				Component:       comp.Name,
				Severity:        SeverityWarning,
				Metric:          metric,
				Current:         value,
				Threshold:       threshold,
				Timestamp:       time.Now(),
				SuggestedAction: action,
			})
		}
		// Z-score anomaly detection against the rolling baseline window;
		// runs regardless of whether the static threshold was breached.
		baseline := hm.metricsDB.GetBaseline(comp.Name, metric, hm.anomalyWindow)
		if IsAnomaly(value, baseline, AnomalyZScoreThreshold) {
			zscore := CalculateZScore(value, baseline)
			alerts = append(alerts, HealthAlert{
				Component: comp.Name,
				Severity:  SeverityCritical,
				Metric:    metric,
				Current:   value,
				// Report the statistical boundary, not the static threshold.
				Threshold:       baseline.Mean + AnomalyZScoreThreshold*baseline.StdDev,
				ZScore:          zscore,
				Timestamp:       time.Now(),
				SuggestedAction: fmt.Sprintf("anomaly detected (Z=%.2f), investigate %s", zscore, metric),
			})
		}
	}
	// Update component status: CRITICAL only after MaxConsecutiveFailures
	// breached checks in a row; any clean check resets counter and status.
	if breached {
		comp.Consecutive++
		if comp.Consecutive >= MaxConsecutiveFailures {
			comp.Status = StatusCritical
		} else {
			comp.Status = StatusDegraded
		}
	} else {
		comp.Consecutive = 0
		comp.Status = StatusHealthy
	}
	return alerts
}
// emitAlert sends an alert to the bus without blocking; when the bus is
// full the alert is dropped and an error is logged instead.
func (hm *HealthMonitor) emitAlert(alert HealthAlert) {
	select {
	case hm.alertBus <- alert:
	default:
		hm.logger.Error("alert bus full, dropping alert",
			"component", alert.Component,
			"metric", alert.Metric,
		)
		return
	}
	hm.logger.Warn("health alert emitted",
		"component", alert.Component,
		"severity", alert.Severity,
		"metric", alert.Metric,
		"current", alert.Current,
		"threshold", alert.Threshold,
	)
}
// validateQuorum checks if 2/3 of components are healthy and, when quorum
// is lost, logs the condition and emits a CRITICAL system-level alert.
// Runs under the read lock; emitAlert is non-blocking so holding the lock
// across it is safe.
func (hm *HealthMonitor) validateQuorum() {
	hm.mu.RLock()
	defer hm.mu.RUnlock()
	if len(hm.components) == 0 {
		return
	}
	if ValidateQuorum(hm.componentStatuses()) {
		return
	}
	// Compute the ratio once so the log line and the alert carry the same
	// value (previously it was recomputed, walking the map twice).
	ratio := hm.healthyRatio()
	hm.logger.Error("QUORUM LOST — entering degraded state",
		"healthy_ratio", ratio,
		"threshold", QuorumThreshold,
	)
	hm.emitAlert(HealthAlert{
		Component:       "system",
		Severity:        SeverityCritical,
		Metric:          "quorum",
		Current:         ratio,
		Threshold:       QuorumThreshold,
		Timestamp:       time.Now(),
		SuggestedAction: "activate safe mode",
	})
}
// ValidateQuorum reports whether the fraction of HEALTHY components meets
// the 2/3 quorum threshold. An empty status map never has quorum.
func ValidateQuorum(statuses map[string]ComponentStatus) bool {
	if len(statuses) == 0 {
		return false
	}
	count := 0
	for _, status := range statuses {
		if status != StatusHealthy {
			continue
		}
		count++
	}
	return float64(count)/float64(len(statuses)) >= QuorumThreshold
}
// componentStatuses builds a name→status snapshot of all registered
// components. Caller must hold at least the read lock.
func (hm *HealthMonitor) componentStatuses() map[string]ComponentStatus {
	out := make(map[string]ComponentStatus, len(hm.components))
	for name, comp := range hm.components {
		out[name] = comp.Status
	}
	return out
}
// healthyRatio returns the fraction of components currently HEALTHY,
// or 0 when none are registered. Caller must hold at least the read lock.
func (hm *HealthMonitor) healthyRatio() float64 {
	total := len(hm.components)
	if total == 0 {
		return 0
	}
	count := 0
	for _, comp := range hm.components {
		if comp.Status == StatusHealthy {
			count++
		}
	}
	return float64(count) / float64(total)
}
// GetHealth returns a point-in-time snapshot of overall system health.
// Component entries (including their metrics maps) are copies, so callers
// may mutate the response freely.
func (hm *HealthMonitor) GetHealth() HealthResponse {
	hm.mu.RLock()
	defer hm.mu.RUnlock()

	snapshot := make([]ComponentHealth, 0, len(hm.components))
	overall := OverallHealthy
	for _, comp := range hm.components {
		entry := *comp
		// Deep-copy the metrics map so the snapshot is isolated.
		entry.Metrics = make(map[string]float64, len(comp.Metrics))
		for name, val := range comp.Metrics {
			entry.Metrics[name] = val
		}
		snapshot = append(snapshot, entry)

		// Aggregate: any CRITICAL/OFFLINE wins over DEGRADED, which wins
		// over HEALTHY.
		switch entry.Status {
		case StatusCritical, StatusOffline:
			overall = OverallCritical
		case StatusDegraded:
			if overall != OverallCritical {
				overall = OverallDegraded
			}
		}
	}
	return HealthResponse{
		OverallStatus: overall,
		Components:    snapshot,
		QuorumValid:   ValidateQuorum(hm.componentStatuses()),
		LastCheck:     time.Now(),
	}
}
// SetComponentStatus manually sets a component's status (for
// testing/override). Unknown names are ignored.
func (hm *HealthMonitor) SetComponentStatus(name string, status ComponentStatus) {
	hm.mu.Lock()
	defer hm.mu.Unlock()
	comp, ok := hm.components[name]
	if !ok {
		return
	}
	comp.Status = status
}
// UpdateMetrics manually replaces a component's metrics (for
// testing/override) and records each value in the time-series DB.
// Unknown names are ignored.
func (hm *HealthMonitor) UpdateMetrics(name string, metrics map[string]float64) {
	hm.mu.Lock()
	defer hm.mu.Unlock()
	comp, ok := hm.components[name]
	if !ok {
		return
	}
	comp.Metrics = metrics
	comp.LastCheck = time.Now()
	for metric, value := range metrics {
		hm.metricsDB.AddDataPoint(name, metric, value)
	}
}
// ComponentCount returns the number of registered components.
func (hm *HealthMonitor) ComponentCount() int {
	hm.mu.RLock()
	n := len(hm.components)
	hm.mu.RUnlock()
	return n
}

View file

@ -0,0 +1,499 @@
package resilience
import (
"context"
"fmt"
"math"
"testing"
"time"
)
// --- MetricsDB Tests ---
func TestRingBuffer_AddAndAll(t *testing.T) {
rb := newRingBuffer(5)
now := time.Now()
for i := 0; i < 3; i++ {
rb.Add(DataPoint{Timestamp: now.Add(time.Duration(i) * time.Second), Value: float64(i)})
}
if rb.Len() != 3 {
t.Fatalf("expected 3, got %d", rb.Len())
}
all := rb.All()
if len(all) != 3 {
t.Fatalf("expected 3 points, got %d", len(all))
}
for i, dp := range all {
if dp.Value != float64(i) {
t.Errorf("point %d: expected %f, got %f", i, float64(i), dp.Value)
}
}
}
// TestRingBuffer_Wrap: writing past capacity overwrites the oldest
// entries, and All() still yields the survivors oldest-first.
func TestRingBuffer_Wrap(t *testing.T) {
	rb := newRingBuffer(3)
	now := time.Now()
	for i := 0; i < 5; i++ {
		rb.Add(DataPoint{Timestamp: now.Add(time.Duration(i) * time.Second), Value: float64(i)})
	}
	if rb.Len() != 3 {
		t.Fatalf("expected 3 (buffer size), got %d", rb.Len())
	}
	all := rb.All()
	// Should contain values 2, 3, 4 (oldest 0, 1 overwritten).
	expected := []float64{2, 3, 4}
	for i, dp := range all {
		if dp.Value != expected[i] {
			t.Errorf("point %d: expected %f, got %f", i, expected[i], dp.Value)
		}
	}
}
// TestMetricsDB_AddAndBaseline: baseline over values 30..34 should count
// all points, have a mean in [30, 35], and a non-zero spread.
func TestMetricsDB_AddAndBaseline(t *testing.T) {
	db := NewMetricsDB(time.Hour, 100)
	for i := 0; i < 20; i++ {
		db.AddDataPoint("soc-ingest", "cpu", 30.0+float64(i%5))
	}
	baseline := db.GetBaseline("soc-ingest", "cpu", time.Hour)
	if baseline.Count != 20 {
		t.Fatalf("expected 20 points, got %d", baseline.Count)
	}
	if baseline.Mean < 30 || baseline.Mean > 35 {
		t.Errorf("mean out of expected range: %f", baseline.Mean)
	}
	if baseline.StdDev == 0 {
		t.Error("expected non-zero stddev")
	}
}
// TestMetricsDB_EmptyBaseline: a series that was never written yields a
// zero-count baseline rather than an error.
func TestMetricsDB_EmptyBaseline(t *testing.T) {
	db := NewMetricsDB(time.Hour, 100)
	baseline := db.GetBaseline("nonexistent", "cpu", time.Hour)
	if got := baseline.Count; got != 0 {
		t.Errorf("expected 0 count for nonexistent, got %d", got)
	}
}
// TestCalculateZScore: (value - mean) / stddev for a well-populated
// baseline; baselines with too few samples yield 0.
func TestCalculateZScore(t *testing.T) {
	baseline := Baseline{Mean: 30.0, StdDev: 5.0, Count: 100}
	// Normal value (Z = 1.0).
	z := CalculateZScore(35.0, baseline)
	if math.Abs(z-1.0) > 0.01 {
		t.Errorf("expected Z≈1.0, got %f", z)
	}
	// Anomalous value (Z = 4.0).
	z = CalculateZScore(50.0, baseline)
	if math.Abs(z-4.0) > 0.01 {
		t.Errorf("expected Z≈4.0, got %f", z)
	}
	// Insufficient data → 0.
	z = CalculateZScore(50.0, Baseline{Mean: 30, StdDev: 5, Count: 5})
	if z != 0 {
		t.Errorf("expected 0 for insufficient data, got %f", z)
	}
}
// TestIsAnomaly: values within ±3σ of the baseline are normal; values
// beyond the threshold in either direction are anomalies.
func TestIsAnomaly(t *testing.T) {
	baseline := Baseline{Mean: 30.0, StdDev: 5.0, Count: 100}
	cases := []struct {
		value float64
		want  bool
		msg   string
	}{
		{35.0, false, "35 should not be anomaly (Z=1.0)"},
		{50.0, true, "50 should be anomaly (Z=4.0)"},
		{10.0, true, "10 should be anomaly (Z=-4.0)"},
	}
	for _, c := range cases {
		if IsAnomaly(c.value, baseline, 3.0) != c.want {
			t.Error(c.msg)
		}
	}
}
// TestMetricsDB_Purge: a point older than the retention window (100ms) is
// removed by Purge while the fresh point survives.
// NOTE(review): relies on a real 150ms sleep — timing-sensitive.
func TestMetricsDB_Purge(t *testing.T) {
	db := NewMetricsDB(100*time.Millisecond, 100)
	db.AddDataPoint("comp", "cpu", 50)
	time.Sleep(150 * time.Millisecond)
	db.AddDataPoint("comp", "cpu", 60)
	removed := db.Purge()
	if removed != 1 {
		t.Errorf("expected 1 purged, got %d", removed)
	}
}
// TestMetricsDB_GetRecent: GetRecent(n) returns the n newest points in
// chronological order.
func TestMetricsDB_GetRecent(t *testing.T) {
	db := NewMetricsDB(time.Hour, 100)
	for i := 0; i < 10; i++ {
		db.AddDataPoint("comp", "mem", float64(i*10))
	}
	recent := db.GetRecent("comp", "mem", 3)
	if len(recent) != 3 {
		t.Fatalf("expected 3 recent, got %d", len(recent))
	}
	// Should be last 3: 70, 80, 90.
	if recent[0].Value != 70 || recent[2].Value != 90 {
		t.Errorf("unexpected recent values: %v", recent)
	}
}
// --- MockCollector for HealthMonitor tests ---
type mockCollector struct {
results map[string]map[string]float64
errors map[string]error
}
func (m *mockCollector) Collect(_ context.Context, component string) (map[string]float64, error) {
if err, ok := m.errors[component]; ok && err != nil {
return nil, err
}
if metrics, ok := m.results[component]; ok {
return metrics, nil
}
return map[string]float64{}, nil
}
// --- HealthMonitor Tests ---

// HM-01: Normal health check — all HEALTHY.
// Freshly registered components default to HEALTHY and satisfy quorum.
func TestHealthMonitor_HM01_AllHealthy(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 10)
	registerTestComponents(hm, 6)
	health := hm.GetHealth()
	if health.OverallStatus != OverallHealthy {
		t.Errorf("expected HEALTHY, got %s", health.OverallStatus)
	}
	if !health.QuorumValid {
		t.Error("expected quorum valid")
	}
	if len(health.Components) != 6 {
		t.Errorf("expected 6 components, got %d", len(health.Components))
	}
}
// HM-02: Single component DEGRADED.
// One degraded component degrades the overall status but keeps quorum.
func TestHealthMonitor_HM02_SingleDegraded(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 10)
	registerTestComponents(hm, 6)
	hm.SetComponentStatus("comp-0", StatusDegraded)
	health := hm.GetHealth()
	if health.OverallStatus != OverallDegraded {
		t.Errorf("expected DEGRADED, got %s", health.OverallStatus)
	}
	if !health.QuorumValid {
		t.Error("expected quorum still valid with 5/6 healthy")
	}
}
// HM-03: Multiple components CRITICAL → quorum lost.
// With 3/6 healthy the ratio (0.5) is below the 2/3 quorum threshold.
func TestHealthMonitor_HM03_MultipleCritical(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 10)
	registerTestComponents(hm, 6)
	hm.SetComponentStatus("comp-0", StatusCritical)
	hm.SetComponentStatus("comp-1", StatusCritical)
	hm.SetComponentStatus("comp-2", StatusCritical)
	health := hm.GetHealth()
	if health.OverallStatus != OverallCritical {
		t.Errorf("expected CRITICAL, got %s", health.OverallStatus)
	}
	if health.QuorumValid {
		t.Error("expected quorum INVALID with 3/6 critical")
	}
}
// HM-04: Anomaly detection (CPU spike).
// A 95% reading against a flat 30% baseline must produce at least one
// alert (threshold breach and/or Z-score anomaly) on the bus.
func TestHealthMonitor_HM04_CPUAnomaly(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 100)
	hm.RegisterComponent(ComponentConfig{
		Name:           "soc-ingest",
		Type:           "go_binary",
		Thresholds:     map[string]float64{"cpu": 80},
		ThresholdIsMax: map[string]bool{"cpu": true},
	})
	// Build baseline of normal CPU (30%).
	for i := 0; i < 50; i++ {
		hm.metricsDB.AddDataPoint("soc-ingest", "cpu", 30.0)
	}
	// Spike to 95%.
	hm.UpdateMetrics("soc-ingest", map[string]float64{"cpu": 95.0})
	hm.checkHealth()
	// Should have alert(s).
	select {
	case alert := <-hm.alertBus:
		if alert.Component != "soc-ingest" {
			t.Errorf("expected soc-ingest, got %s", alert.Component)
		}
		if alert.Metric != "cpu" {
			t.Errorf("expected cpu metric, got %s", alert.Metric)
		}
	default:
		t.Error("expected alert for CPU spike")
	}
}
// HM-05: Memory leak detection.
// Same shape as HM-04 but for the memory metric: a spike to 95% against a
// flat 40% baseline must produce an alert.
func TestHealthMonitor_HM05_MemoryLeak(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 100)
	hm.RegisterComponent(ComponentConfig{
		Name:           "soc-correlate",
		Type:           "go_binary",
		Thresholds:     map[string]float64{"memory": 90},
		ThresholdIsMax: map[string]bool{"memory": true},
	})
	// Build baseline of normal memory (40%).
	for i := 0; i < 50; i++ {
		hm.metricsDB.AddDataPoint("soc-correlate", "memory", 40.0)
	}
	// Memory spike to 95%.
	hm.UpdateMetrics("soc-correlate", map[string]float64{"memory": 95.0})
	hm.checkHealth()
	select {
	case alert := <-hm.alertBus:
		if alert.Metric != "memory" {
			t.Errorf("expected memory metric, got %s", alert.Metric)
		}
	default:
		t.Error("expected alert for memory spike")
	}
}
// HM-06: Quorum validation failure — 2/6 healthy is below the 2/3 bar.
func TestHealthMonitor_HM06_QuorumFailure(t *testing.T) {
	statuses := make(map[string]ComponentStatus, 6)
	for _, name := range []string{"a", "b", "c", "d"} {
		statuses[name] = StatusOffline
	}
	statuses["e"] = StatusHealthy
	statuses["f"] = StatusHealthy
	if ValidateQuorum(statuses) {
		t.Error("expected quorum invalid with 4/6 offline")
	}
}
// HM-06b: Quorum validation success at exactly 2/3 healthy — the edge of
// the threshold must still count as quorum.
func TestHealthMonitor_HM06b_QuorumEdge(t *testing.T) {
	statuses := map[string]ComponentStatus{
		"a": StatusHealthy,
		"b": StatusHealthy,
		"c": StatusCritical,
	}
	ok := ValidateQuorum(statuses)
	if !ok {
		t.Error("expected quorum valid with 2/3 healthy (exact threshold)")
	}
}
// HM-06c: Empty quorum.
func TestHealthMonitor_HM06c_EmptyQuorum(t *testing.T) {
if ValidateQuorum(map[string]ComponentStatus{}) {
t.Error("expected quorum invalid with 0 components")
}
}
// HM-07: Metrics collection (no data loss).
// Values returned by the collector must land unchanged in the
// component's metrics map.
func TestHealthMonitor_HM07_MetricsCollection(t *testing.T) {
	collector := &mockCollector{
		results: map[string]map[string]float64{
			"comp-0": {"cpu": 25, "memory": 40},
		},
	}
	hm := NewHealthMonitor(collector, 10)
	hm.RegisterComponent(ComponentConfig{Name: "comp-0", Type: "go_binary"})
	hm.collectMetrics(context.Background())
	hm.mu.RLock()
	comp := hm.components["comp-0"]
	hm.mu.RUnlock()
	if comp.Metrics["cpu"] != 25 {
		t.Errorf("expected cpu=25, got %f", comp.Metrics["cpu"])
	}
	if comp.Metrics["memory"] != 40 {
		t.Errorf("expected memory=40, got %f", comp.Metrics["memory"])
	}
}
// HM-07b: Collection error increments consecutive failures.
func TestHealthMonitor_HM07b_CollectionError(t *testing.T) {
collector := &mockCollector{
errors: map[string]error{
"comp-0": fmt.Errorf("connection refused"),
},
}
hm := NewHealthMonitor(collector, 10)
hm.RegisterComponent(ComponentConfig{Name: "comp-0", Type: "go_binary"})
hm.collectMetrics(context.Background())
hm.mu.RLock()
comp := hm.components["comp-0"]
hm.mu.RUnlock()
if comp.Consecutive != 1 {
t.Errorf("expected 1 consecutive failure, got %d", comp.Consecutive)
}
}
// HM-08: Alert bus fan-out (non-blocking).
func TestHealthMonitor_HM08_AlertBusFanOut(t *testing.T) {
hm := NewHealthMonitor(&mockCollector{}, 5)
hm.RegisterComponent(ComponentConfig{
Name: "comp",
Type: "go_binary",
Thresholds: map[string]float64{"cpu": 50},
ThresholdIsMax: map[string]bool{"cpu": true},
})
// Fill alert bus.
for i := 0; i < 5; i++ {
hm.alertBus <- HealthAlert{Component: fmt.Sprintf("test-%d", i)}
}
// Emit one more — should be dropped (non-blocking).
hm.emitAlert(HealthAlert{Component: "overflow"})
// No panic = success.
}
// Test GetHealth returns a deep copy.
func TestHealthMonitor_GetHealthDeepCopy(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 10)
	hm.RegisterComponent(ComponentConfig{Name: "test", Type: "go_binary"})
	hm.UpdateMetrics("test", map[string]float64{"cpu": 50})
	health := hm.GetHealth()
	// Mutate the returned snapshot's metric map…
	health.Components[0].Metrics["cpu"] = 999
	// …the original internal state should be unchanged.
	hm.mu.RLock()
	original := hm.components["test"].Metrics["cpu"]
	hm.mu.RUnlock()
	if original != 50 {
		t.Errorf("deep copy failed: original modified to %f", original)
	}
}

// Test threshold breach transitions status to DEGRADED then CRITICAL.
func TestHealthMonitor_StatusTransitions(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 100)
	hm.RegisterComponent(ComponentConfig{
		Name:           "comp",
		Type:           "go_binary",
		Thresholds:     map[string]float64{"error_rate": 5},
		ThresholdIsMax: map[string]bool{"error_rate": true},
	})
	// Breach once → DEGRADED.
	hm.UpdateMetrics("comp", map[string]float64{"error_rate": 10})
	hm.checkHealth()
	hm.mu.RLock()
	status := hm.components["comp"].Status
	hm.mu.RUnlock()
	if status != StatusDegraded {
		t.Errorf("expected DEGRADED after 1 breach, got %s", status)
	}
	// Breach 3× → CRITICAL (metric still above threshold on every check).
	for i := 0; i < 3; i++ {
		hm.checkHealth()
	}
	hm.mu.RLock()
	status = hm.components["comp"].Status
	hm.mu.RUnlock()
	if status != StatusCritical {
		t.Errorf("expected CRITICAL after repeated breaches, got %s", status)
	}
}
// Test lower-bound threshold (ThresholdIsMax=false).
// A metric BELOW its threshold must trigger an alert when the threshold
// is declared as a minimum.
func TestHealthMonitor_LowerBoundThreshold(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 100)
	hm.RegisterComponent(ComponentConfig{
		Name:           "immune",
		Type:           "c_kernel_module",
		Thresholds:     map[string]float64{"hooks_active": 10},
		ThresholdIsMax: map[string]bool{"hooks_active": false},
	})
	// hooks_active = 5 (below threshold of 10) → warning.
	hm.UpdateMetrics("immune", map[string]float64{"hooks_active": 5})
	hm.checkHealth()
	select {
	case alert := <-hm.alertBus:
		if alert.Component != "immune" || alert.Metric != "hooks_active" {
			t.Errorf("unexpected alert: %+v", alert)
		}
	default:
		t.Error("expected alert for hooks_active below threshold")
	}
}

// Test ComponentCount.
func TestHealthMonitor_ComponentCount(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 10)
	if hm.ComponentCount() != 0 {
		t.Error("expected 0 initially")
	}
	registerTestComponents(hm, 4)
	if hm.ComponentCount() != 4 {
		t.Errorf("expected 4, got %d", hm.ComponentCount())
	}
}

// Test Start/Stop lifecycle: Start must return promptly once the context
// is cancelled.
func TestHealthMonitor_StartStop(t *testing.T) {
	hm := NewHealthMonitor(&mockCollector{}, 10)
	registerTestComponents(hm, 2)
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		hm.Start(ctx)
		close(done)
	}()
	// Let it run briefly.
	time.Sleep(50 * time.Millisecond)
	cancel()
	select {
	case <-done:
		// Clean shutdown.
	case <-time.After(time.Second):
		t.Fatal("Start() did not return after context cancellation")
	}
}
// --- Helpers ---

// registerTestComponents registers n go_binary components named
// comp-0 … comp-(n-1) on the given monitor.
func registerTestComponents(hm *HealthMonitor, n int) {
	for idx := 0; idx < n; idx++ {
		cfg := ComponentConfig{
			Name: fmt.Sprintf("comp-%d", idx),
			Type: "go_binary",
		}
		hm.RegisterComponent(cfg)
	}
}

View file

@ -0,0 +1,247 @@
package resilience
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log/slog"
"os"
"sync"
"time"
)
// IntegrityStatus represents the result of an integrity check.
type IntegrityStatus string

const (
	IntegrityVerified    IntegrityStatus = "VERIFIED"    // all checks passed
	IntegrityCompromised IntegrityStatus = "COMPROMISED" // a hash, HMAC, or chain check failed
	IntegrityUnknown     IntegrityStatus = "UNKNOWN"     // check could not run (e.g. unreadable file)
)
// IntegrityReport is the full result of an integrity verification.
type IntegrityReport struct {
	Overall   IntegrityStatus         `json:"overall"` // COMPROMISED if any sub-check failed
	Timestamp time.Time               `json:"timestamp"`
	Binaries  map[string]BinaryStatus `json:"binaries,omitempty"` // keyed by binary path
	Chain     *ChainStatus            `json:"chain,omitempty"`    // nil when no chain path is set
	Configs   map[string]ConfigStatus `json:"configs,omitempty"`  // keyed by config path
}

// BinaryStatus is the integrity status of a single binary.
type BinaryStatus struct {
	Status   IntegrityStatus `json:"status"`
	Expected string          `json:"expected"` // hex SHA-256 registered up front
	Current  string          `json:"current"`  // hex SHA-256 measured, or "error: …" text
}

// ChainStatus is the integrity status of the decision chain.
type ChainStatus struct {
	Valid      bool   `json:"valid"`
	Error      string `json:"error,omitempty"`
	BreakPoint int    `json:"break_point,omitempty"` // not yet populated by verifyDecisionChain
	Entries    int    `json:"entries"`
}

// ConfigStatus is the integrity status of a config file.
type ConfigStatus struct {
	Valid       bool   `json:"valid"`
	Error       string `json:"error,omitempty"`
	StoredHMAC  string `json:"stored_hmac,omitempty"` // sidecar HMAC — not yet populated
	CurrentHMAC string `json:"current_hmac,omitempty"`
}
// IntegrityVerifier performs periodic integrity checks on binaries,
// decision chain, and config files. Safe for concurrent use; VerifyAll
// performs file I/O outside the mutex.
type IntegrityVerifier struct {
	mu           sync.RWMutex
	binaryHashes map[string]string // path → expected SHA-256
	configPaths  []string          // config files to verify
	hmacKey      []byte            // key for config HMAC-SHA256; set once at construction
	chainPath    string            // path to decision chain log
	logger       *slog.Logger
	lastReport   *IntegrityReport // most recent VerifyAll result (nil before first run)
}
// NewIntegrityVerifier creates a new integrity verifier. hmacKey is used
// for HMAC-SHA256 over registered config files.
func NewIntegrityVerifier(hmacKey []byte) *IntegrityVerifier {
	verifier := &IntegrityVerifier{
		binaryHashes: map[string]string{},
		hmacKey:      hmacKey,
	}
	verifier.logger = slog.Default().With("component", "sarl-integrity")
	return verifier
}
// RegisterBinary adds a binary with its expected SHA-256 hash (hex-encoded).
// Re-registering a path overwrites the previous expected hash.
func (iv *IntegrityVerifier) RegisterBinary(path, expectedHash string) {
	iv.mu.Lock()
	defer iv.mu.Unlock()
	iv.binaryHashes[path] = expectedHash
}

// RegisterConfig adds a config file to verify on each VerifyAll pass.
// Duplicate registrations are not deduplicated.
func (iv *IntegrityVerifier) RegisterConfig(path string) {
	iv.mu.Lock()
	defer iv.mu.Unlock()
	iv.configPaths = append(iv.configPaths, path)
}

// SetChainPath sets the decision chain log path (empty disables the chain check).
func (iv *IntegrityVerifier) SetChainPath(path string) {
	iv.mu.Lock()
	defer iv.mu.Unlock()
	iv.chainPath = path
}
// VerifyAll runs all integrity checks and returns a comprehensive report.
//
// File I/O (binary hashing, config reading, chain stat) is done WITHOUT
// holding the mutex to prevent thread starvation on slow storage: the
// registered paths/hashes are snapshotted under RLock, then the lock is
// released for the duration of the checks. The HMAC key is assigned only
// in NewIntegrityVerifier (there is no setter), so verifyConfigFile may
// read it lock-free — the previous unused hmacKey snapshot was dead code
// and has been removed.
func (iv *IntegrityVerifier) VerifyAll() IntegrityReport {
	report := IntegrityReport{
		Overall:   IntegrityVerified,
		Timestamp: time.Now(),
		Binaries:  make(map[string]BinaryStatus),
		Configs:   make(map[string]ConfigStatus),
	}

	// Snapshot config under lock, then release before I/O.
	iv.mu.RLock()
	binaryHashesCopy := make(map[string]string, len(iv.binaryHashes))
	for k, v := range iv.binaryHashes {
		binaryHashesCopy[k] = v
	}
	configPathsCopy := make([]string, len(iv.configPaths))
	copy(configPathsCopy, iv.configPaths)
	chainPath := iv.chainPath
	iv.mu.RUnlock()

	// Check binaries (file I/O — no lock held).
	for path, expected := range binaryHashesCopy {
		status := iv.verifyBinary(path, expected)
		report.Binaries[path] = status
		if status.Status == IntegrityCompromised {
			report.Overall = IntegrityCompromised
		}
	}
	// Check configs (file I/O — no lock held).
	for _, path := range configPathsCopy {
		status := iv.verifyConfigFile(path)
		report.Configs[path] = status
		if !status.Valid {
			report.Overall = IntegrityCompromised
		}
	}
	// Check decision chain (file I/O — no lock held).
	if chainPath != "" {
		chain := iv.verifyDecisionChain(chainPath)
		report.Chain = &chain
		if !chain.Valid {
			report.Overall = IntegrityCompromised
		}
	}

	// Publish the report for LastReport readers.
	iv.mu.Lock()
	iv.lastReport = &report
	iv.mu.Unlock()

	if report.Overall == IntegrityCompromised {
		iv.logger.Error("INTEGRITY COMPROMISED", "report", report)
	} else {
		iv.logger.Debug("integrity verified", "binaries", len(report.Binaries))
	}
	return report
}
// LastReport returns the most recent integrity report, or nil before the
// first VerifyAll call.
// NOTE(review): the returned pointer aliases internal state — callers must
// treat the report (and its maps) as read-only.
func (iv *IntegrityVerifier) LastReport() *IntegrityReport {
	iv.mu.RLock()
	defer iv.mu.RUnlock()
	return iv.lastReport
}
// verifyBinary calculates SHA-256 of a file and compares it to expected.
// An unreadable file yields IntegrityUnknown with the error text in Current.
func (iv *IntegrityVerifier) verifyBinary(path, expected string) BinaryStatus {
	result := BinaryStatus{Expected: expected}
	current, err := fileSHA256(path)
	switch {
	case err != nil:
		result.Status = IntegrityUnknown
		result.Current = fmt.Sprintf("error: %v", err)
	case current != expected:
		result.Status = IntegrityCompromised
		result.Current = current
	default:
		result.Status = IntegrityVerified
		result.Current = current
	}
	return result
}
// verifyConfigFile checks HMAC-SHA256 of a config file.
// Currently this only confirms the file is readable and records its HMAC;
// in production, the stored HMAC would be extracted from a sidecar file.
// iv.hmacKey is assigned only at construction, so the lock-free read is safe.
func (iv *IntegrityVerifier) verifyConfigFile(path string) ConfigStatus {
	data, err := os.ReadFile(path)
	if err != nil {
		return ConfigStatus{Valid: false, Error: fmt.Sprintf("unreadable: %v", err)}
	}
	return ConfigStatus{
		Valid:       true,
		CurrentHMAC: computeHMAC(data, iv.hmacKey),
	}
}
// verifyDecisionChain verifies the decision chain log exists and is stat-able.
// A missing file is treated as a valid, empty chain. Full hash-chain
// verification (each entry covering the previous entry's hash) is not yet
// implemented — today this only checks file presence/readability.
func (iv *IntegrityVerifier) verifyDecisionChain(path string) ChainStatus {
	if _, err := os.Stat(path); err != nil {
		if os.IsNotExist(err) {
			// No chain yet — vacuously valid.
			return ChainStatus{Valid: true, Entries: 0}
		}
		return ChainStatus{Valid: false, Error: fmt.Sprintf("unreadable: %v", err)}
	}
	return ChainStatus{Valid: true}
}
// fileSHA256 computes the SHA-256 hash of a file.
func fileSHA256(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
// computeHMAC returns the hex-encoded HMAC-SHA256 of data under key.
func computeHMAC(data, key []byte) string {
	h := hmac.New(sha256.New, key)
	h.Write(data) // hash.Hash.Write never returns an error
	sum := h.Sum(nil)
	return hex.EncodeToString(sum)
}

View file

@ -0,0 +1,283 @@
// Package resilience implements the Sentinel Autonomous Resilience Layer (SARL).
//
// Five levels of autonomous self-recovery:
//
// L1 — Self-Monitoring: health checks, quorum, anomaly detection
// L2 — Self-Healing: restart, rollback, recovery strategies
// L3 — Self-Preservation: emergency modes (safe/lockdown/apoptosis)
// L4 — Immune Integration: behavioral anomaly detection
// L5 — Autonomous Recovery: playbooks for resurrection, consensus, crypto
package resilience
import (
"math"
"sync"
"time"
)
// MetricsDB provides an in-memory time-series store with ring buffers
// for each component/metric pair. Supports rolling baselines (mean/stddev)
// for Z-score anomaly detection. Safe for concurrent use.
type MetricsDB struct {
	mu      sync.RWMutex
	series  map[string]*RingBuffer // key = "component:metric"
	window  time.Duration          // retention window (default 1h)
	maxSize int                    // max data points per series
}

// DataPoint is a single timestamped metric value.
type DataPoint struct {
	Timestamp time.Time `json:"timestamp"`
	Value     float64   `json:"value"`
}

// Baseline holds rolling statistics for anomaly detection.
type Baseline struct {
	Mean   float64 `json:"mean"`
	StdDev float64 `json:"std_dev"` // population (not sample) standard deviation
	Count  int     `json:"count"`
	Min    float64 `json:"min"`
	Max    float64 `json:"max"`
}

// RingBuffer is a fixed-size circular buffer for DataPoints.
// Not safe for concurrent use on its own; MetricsDB serializes access.
type RingBuffer struct {
	data  []DataPoint
	head  int // next write slot; also the oldest element when full
	count int // number of valid entries (<= size)
	size  int // fixed capacity
}

// DefaultMetricsWindow is the default retention window (1 hour).
const DefaultMetricsWindow = 1 * time.Hour

// DefaultMetricsMaxSize is the default max points per series (1h / 10s = 360).
const DefaultMetricsMaxSize = 360
// NewMetricsDB creates a new in-memory time-series store. Non-positive
// window or maxSize values fall back to the package defaults.
func NewMetricsDB(window time.Duration, maxSize int) *MetricsDB {
	db := &MetricsDB{series: map[string]*RingBuffer{}}
	db.window = window
	if db.window <= 0 {
		db.window = DefaultMetricsWindow
	}
	db.maxSize = maxSize
	if db.maxSize <= 0 {
		db.maxSize = DefaultMetricsMaxSize
	}
	return db
}
// AddDataPoint records a metric value for a component, timestamped now.
// The per-series ring buffer is created lazily on first use.
func (db *MetricsDB) AddDataPoint(component, metric string, value float64) {
	db.mu.Lock()
	defer db.mu.Unlock()

	key := component + ":" + metric
	rb := db.series[key]
	if rb == nil {
		rb = newRingBuffer(db.maxSize)
		db.series[key] = rb
	}
	rb.Add(DataPoint{Timestamp: time.Now(), Value: value})
}
// GetBaseline returns rolling mean/stddev/min/max for a component metric,
// computed over points newer than now-window. Returns a zero Baseline when
// the series is unknown or has no recent points.
func (db *MetricsDB) GetBaseline(component, metric string, window time.Duration) Baseline {
	db.mu.RLock()
	defer db.mu.RUnlock()

	rb, ok := db.series[component+":"+metric]
	if !ok {
		return Baseline{}
	}
	points := rb.After(time.Now().Add(-window))
	if len(points) == 0 {
		return Baseline{}
	}
	return calculateBaseline(points)
}
// GetRecent returns up to the n most recent data points for a component
// metric, oldest first; nil when the series does not exist.
func (db *MetricsDB) GetRecent(component, metric string, n int) []DataPoint {
	db.mu.RLock()
	defer db.mu.RUnlock()

	rb, ok := db.series[component+":"+metric]
	if !ok {
		return nil
	}
	all := rb.All()
	if extra := len(all) - n; extra > 0 {
		return all[extra:]
	}
	return all
}
// CalculateZScore returns the Z-score for a value against the baseline.
// Returns 0 when the baseline has fewer than 10 samples or a zero stddev
// (insufficient data for a meaningful score).
func CalculateZScore(value float64, baseline Baseline) float64 {
	if baseline.StdDev == 0 || baseline.Count < 10 {
		return 0
	}
	return (value - baseline.Mean) / baseline.StdDev
}
// IsAnomaly reports whether the absolute Z-score of value against the
// baseline exceeds the threshold. Non-positive thresholds default to 3.0.
func IsAnomaly(value float64, baseline Baseline, threshold float64) bool {
	if threshold <= 0 {
		threshold = 3.0
	}
	return math.Abs(CalculateZScore(value, baseline)) > threshold
}
// SeriesCount returns the number of tracked component:metric series.
func (db *MetricsDB) SeriesCount() int {
	db.mu.RLock()
	defer db.mu.RUnlock()
	return len(db.series)
}
// Purge removes data points older than the retention window and drops
// series that become empty. Returns the total number of points removed.
// (Deleting map entries while ranging over the map is safe in Go.)
func (db *MetricsDB) Purge() int {
	db.mu.Lock()
	defer db.mu.Unlock()

	cutoff := time.Now().Add(-db.window)
	removed := 0
	for key, rb := range db.series {
		removed += rb.RemoveBefore(cutoff)
		if rb.Len() == 0 {
			delete(db.series, key)
		}
	}
	return removed
}
// --- RingBuffer implementation ---

// newRingBuffer allocates an empty ring buffer holding up to size points.
func newRingBuffer(size int) *RingBuffer {
	return &RingBuffer{
		data: make([]DataPoint, size),
		size: size,
	}
}
// Add inserts a DataPoint at the write head, overwriting the oldest entry
// once the buffer is full.
func (rb *RingBuffer) Add(dp DataPoint) {
	rb.data[rb.head] = dp
	rb.head++
	if rb.head == rb.size {
		rb.head = 0
	}
	if rb.count < rb.size {
		rb.count++
	}
}

// Len returns the number of data points currently stored.
func (rb *RingBuffer) Len() int { return rb.count }
// All returns the stored data points in chronological order (oldest first),
// or nil when the buffer is empty.
func (rb *RingBuffer) All() []DataPoint {
	if rb.count == 0 {
		return nil
	}
	// When full, the oldest element sits at head; otherwise data starts at 0.
	start := 0
	if rb.count == rb.size {
		start = rb.head
	}
	out := make([]DataPoint, 0, rb.count)
	for i := 0; i < rb.count; i++ {
		out = append(out, rb.data[(start+i)%rb.size])
	}
	return out
}
// After returns all points strictly newer than cutoff, oldest first.
// The result is a non-nil (possibly empty) slice.
func (rb *RingBuffer) After(cutoff time.Time) []DataPoint {
	result := make([]DataPoint, 0, rb.Len())
	for _, dp := range rb.All() {
		if dp.Timestamp.After(cutoff) {
			result = append(result, dp)
		}
	}
	return result
}
// RemoveBefore drops all points whose timestamp is earlier than cutoff by
// rebuilding the buffer from the surviving points (order preserved).
// Returns the number of points removed.
func (rb *RingBuffer) RemoveBefore(cutoff time.Time) int {
	all := rb.All()
	var kept []DataPoint
	for _, dp := range all {
		if !dp.Timestamp.Before(cutoff) {
			kept = append(kept, dp)
		}
	}
	removed := len(all) - len(kept)
	if removed > 0 {
		// Reset and re-insert survivors in chronological order.
		rb.head, rb.count = 0, 0
		for _, dp := range kept {
			rb.Add(dp)
		}
	}
	return removed
}
// --- Statistics ---

// calculateBaseline computes mean, population standard deviation, min and
// max over the given points. Returns a zero Baseline for an empty slice.
func calculateBaseline(points []DataPoint) Baseline {
	if len(points) == 0 {
		return Baseline{}
	}
	b := Baseline{
		Count: len(points),
		Min:   points[0].Value,
		Max:   points[0].Value,
	}
	var sum float64
	for _, p := range points {
		sum += p.Value
		if p.Value < b.Min {
			b.Min = p.Value
		}
		if p.Value > b.Max {
			b.Max = p.Value
		}
	}
	b.Mean = sum / float64(len(points))

	var variance float64
	for _, p := range points {
		d := p.Value - b.Mean
		variance += d * d
	}
	// Population variance (divide by n, not n-1).
	b.StdDev = math.Sqrt(variance / float64(len(points)))
	return b
}

View file

@ -0,0 +1,290 @@
package resilience
import (
"fmt"
"log/slog"
"sync"
"time"
)
// EmergencyMode defines the system's emergency state.
type EmergencyMode string

const (
	ModeNone      EmergencyMode = "NONE"      // normal operation
	ModeSafe      EmergencyMode = "SAFE"      // non-essential services off; auto-exits after a timer
	ModeLockdown  EmergencyMode = "LOCKDOWN"  // network-isolated and frozen; manual deactivation
	ModeApoptosis EmergencyMode = "APOPTOSIS" // terminal self-shutdown; cannot be deactivated
)

// ModeActivation records when and why a mode was activated.
type ModeActivation struct {
	Mode        EmergencyMode `json:"mode"`
	ActivatedAt time.Time     `json:"activated_at"`
	ActivatedBy string        `json:"activated_by"` // "auto" or "architect:<name>"
	Reason      string        `json:"reason"`
	AutoExit    bool          `json:"auto_exit"` // set only for SAFE mode
	AutoExitAt  time.Time     `json:"auto_exit_at,omitempty"`
}

// PreservationEvent is an audit log entry for preservation actions.
type PreservationEvent struct {
	Timestamp time.Time     `json:"timestamp"`
	Mode      EmergencyMode `json:"mode"`
	Action    string        `json:"action"`
	Detail    string        `json:"detail"`
	Success   bool          `json:"success"`
	Error     string        `json:"error,omitempty"`
}

// ModeActionFunc is a callback to perform mode-specific actions.
// Implementations handle the real system operations (network isolation, process freeze, etc.).
type ModeActionFunc func(mode EmergencyMode, action string, params map[string]interface{}) error
// PreservationEngine manages emergency modes (safe/lockdown/apoptosis).
// All state is guarded by mu; real system operations are delegated to actionFn.
type PreservationEngine struct {
	mu          sync.RWMutex
	currentMode EmergencyMode
	activation  *ModeActivation     // nil while currentMode is NONE
	history     []PreservationEvent // append-only audit log
	actionFn    ModeActionFunc
	integrityFn func() IntegrityReport // pluggable integrity check
	logger      *slog.Logger
}
// NewPreservationEngine creates a new preservation engine in ModeNone.
// actionFn performs the real mode-specific system operations.
func NewPreservationEngine(actionFn ModeActionFunc) *PreservationEngine {
	pe := &PreservationEngine{
		currentMode: ModeNone,
		actionFn:    actionFn,
		logger:      slog.Default().With("component", "sarl-preservation"),
	}
	pe.history = []PreservationEvent{}
	return pe
}
// CurrentMode returns the active emergency mode.
func (pe *PreservationEngine) CurrentMode() EmergencyMode {
	pe.mu.RLock()
	defer pe.mu.RUnlock()
	return pe.currentMode
}

// Activation returns a copy of the current mode activation details,
// or nil when no mode is active.
func (pe *PreservationEngine) Activation() *ModeActivation {
	pe.mu.RLock()
	defer pe.mu.RUnlock()
	if pe.activation == nil {
		return nil
	}
	clone := *pe.activation
	return &clone
}
// ActivateMode enters an emergency mode. Returns error if transition is invalid.
//
// Ordering is deliberate: mode actions run BEFORE the state change, so a
// failed activation (non-apoptosis) leaves currentMode untouched. Apoptosis
// is best-effort — action errors are logged but do not abort termination.
func (pe *PreservationEngine) ActivateMode(mode EmergencyMode, reason, activatedBy string) error {
	pe.mu.Lock()
	defer pe.mu.Unlock()
	if mode == ModeNone {
		return fmt.Errorf("use DeactivateMode to exit emergency mode")
	}
	// Validate transitions: can always escalate, can't downgrade.
	if !pe.isValidTransition(pe.currentMode, mode) {
		return fmt.Errorf("invalid transition: %s → %s", pe.currentMode, mode)
	}
	pe.logger.Warn("EMERGENCY MODE ACTIVATION",
		"mode", mode,
		"reason", reason,
		"activated_by", activatedBy,
	)
	// Execute mode-specific actions.
	actions := pe.actionsForMode(mode)
	for _, action := range actions {
		err := pe.executeAction(mode, action.name, action.params)
		if err != nil {
			pe.logger.Error("mode action failed",
				"mode", mode,
				"action", action.name,
				"error", err,
			)
			// In critical modes, continue despite errors.
			if mode != ModeApoptosis {
				return fmt.Errorf("failed to activate %s: action %s: %w", mode, action.name, err)
			}
		}
	}
	activation := &ModeActivation{
		Mode:        mode,
		ActivatedAt: time.Now(),
		ActivatedBy: activatedBy,
		Reason:      reason,
	}
	if mode == ModeSafe {
		// Safe mode auto-exits after 15 minutes (checked via ShouldAutoExit).
		activation.AutoExit = true
		activation.AutoExitAt = time.Now().Add(15 * time.Minute)
	}
	pe.currentMode = mode
	pe.activation = activation
	return nil
}
// DeactivateMode exits the current emergency mode and returns to NONE.
// Deactivating from NONE is a no-op. Safe and lockdown modes can both be
// deactivated; apoptosis is terminal and always returns an error.
func (pe *PreservationEngine) DeactivateMode(deactivatedBy string) error {
	pe.mu.Lock()
	defer pe.mu.Unlock()
	if pe.currentMode == ModeNone {
		return nil
	}
	// Apoptosis is irreversible — the system must be rebuilt.
	if pe.currentMode == ModeApoptosis {
		return fmt.Errorf("apoptosis mode cannot be deactivated — system rebuild required")
	}
	pe.logger.Info("EMERGENCY MODE DEACTIVATION",
		"mode", pe.currentMode,
		"deactivated_by", deactivatedBy,
	)
	pe.recordEvent(pe.currentMode, "deactivated",
		fmt.Sprintf("deactivated by %s", deactivatedBy), true, "")
	pe.currentMode = ModeNone
	pe.activation = nil
	return nil
}
// ShouldAutoExit reports whether safe mode's auto-exit timer has expired.
// Only safe mode ever auto-exits; every other mode returns false.
func (pe *PreservationEngine) ShouldAutoExit() bool {
	pe.mu.RLock()
	defer pe.mu.RUnlock()
	if pe.currentMode != ModeSafe || pe.activation == nil {
		return false
	}
	act := pe.activation
	return act.AutoExit && time.Now().After(act.AutoExitAt)
}
// isValidTransition checks whether a mode transition is allowed.
// Escalation order: NONE → SAFE → LOCKDOWN → APOPTOSIS. Re-entering the
// same mode is permitted; downgrading is not.
func (pe *PreservationEngine) isValidTransition(from, to EmergencyMode) bool {
	return modeRank(to) >= modeRank(from)
}

// modeRank maps a mode to its escalation rank (unrecognized modes rank 0,
// matching NONE — the same behavior as a missing map key in a rank table).
func modeRank(m EmergencyMode) int {
	switch m {
	case ModeSafe:
		return 1
	case ModeLockdown:
		return 2
	case ModeApoptosis:
		return 3
	default:
		return 0
	}
}
// modeAction is a single named action with its parameters.
type modeAction struct {
	name   string
	params map[string]interface{}
}

// actionsForMode returns the ordered actions to execute for a given mode.
// The action-name strings form the contract with ModeActionFunc
// implementations (and tests) — treat them as stable identifiers.
func (pe *PreservationEngine) actionsForMode(mode EmergencyMode) []modeAction {
	switch mode {
	case ModeSafe:
		return []modeAction{
			{"disable_non_essential_services", map[string]interface{}{
				"services": []string{"analytics", "reporting", "p2p_sync", "threat_intel_feeds"},
			}},
			{"enable_readonly_mode", map[string]interface{}{
				"scope": []string{"event_ingest", "correlation", "dashboard_view"},
			}},
			{"preserve_all_logs", nil},
			{"notify_architect", map[string]interface{}{"severity": "emergency"}},
			{"increase_monitoring_frequency", map[string]interface{}{"interval": "5s"}},
		}
	case ModeLockdown:
		return []modeAction{
			{"isolate_from_network", map[string]interface{}{"scope": "all_external"}},
			{"freeze_all_processes", nil},
			{"capture_memory_dump", nil},
			{"capture_disk_snapshot", nil},
			{"trigger_immune_kernel_lock", map[string]interface{}{
				"allow_syscalls": []string{"read", "write", "exit"},
			}},
			{"send_panic_alert", map[string]interface{}{
				"channels": []string{"email", "sms", "slack", "pagerduty"},
			}},
		}
	case ModeApoptosis:
		return []modeAction{
			{"graceful_shutdown", map[string]interface{}{"timeout": "30s", "drain_events": true}},
			{"zero_sensitive_memory", map[string]interface{}{
				"regions": []string{"keys", "certs", "tokens", "secrets"},
			}},
			{"preserve_forensic_evidence", nil},
			{"notify_soc", map[string]interface{}{
				"severity": "CRITICAL",
				"message":  "system self-terminated",
			}},
			{"secure_erase_temp_files", nil},
		}
	}
	// Unknown mode (including ModeNone): nothing to do.
	return nil
}
// executeAction runs a single mode action via the callback and records the
// outcome in the audit history. Returns the callback's error unchanged.
func (pe *PreservationEngine) executeAction(mode EmergencyMode, name string, params map[string]interface{}) error {
	err := pe.actionFn(mode, name, params)
	var errStr string
	if err != nil {
		errStr = err.Error()
	}
	pe.recordEvent(mode, name, fmt.Sprintf("params: %v", params), err == nil, errStr)
	return err
}
// recordEvent appends an entry to the audit history.
// Callers must hold pe.mu (both call sites in this file do).
func (pe *PreservationEngine) recordEvent(mode EmergencyMode, action, detail string, success bool, errStr string) {
	pe.history = append(pe.history, PreservationEvent{
		Timestamp: time.Now(),
		Mode:      mode,
		Action:    action,
		Detail:    detail,
		Success:   success,
		Error:     errStr,
	})
}
// History returns a copy of the preservation audit log.
func (pe *PreservationEngine) History() []PreservationEvent {
	pe.mu.RLock()
	defer pe.mu.RUnlock()
	out := make([]PreservationEvent, len(pe.history))
	copy(out, pe.history)
	return out
}

// SetIntegrityCheck sets the pluggable integrity checker.
func (pe *PreservationEngine) SetIntegrityCheck(fn func() IntegrityReport) {
	pe.mu.Lock()
	pe.integrityFn = fn
	pe.mu.Unlock()
}

// CheckIntegrity runs the pluggable integrity check and returns the report.
// With no checker configured, integrity is assumed VERIFIED.
func (pe *PreservationEngine) CheckIntegrity() IntegrityReport {
	pe.mu.RLock()
	fn := pe.integrityFn
	pe.mu.RUnlock()
	if fn == nil {
		return IntegrityReport{Overall: IntegrityVerified, Timestamp: time.Now()}
	}
	return fn()
}

View file

@ -0,0 +1,439 @@
package resilience
import (
"crypto/sha256"
"encoding/hex"
"os"
"path/filepath"
"testing"
"time"
)
// --- Mock action function ---

// modeActionLog records every (mode, action) pair the engine executes and
// can simulate a failure for one named action.
type modeActionLog struct {
	calls []struct {
		mode   EmergencyMode
		action string
	}
	failAction string // if set, this action will fail
}

func newModeActionLog() *modeActionLog {
	return &modeActionLog{}
}

// execute satisfies ModeActionFunc: it logs the call, then fails only when
// the action name matches failAction.
func (m *modeActionLog) execute(mode EmergencyMode, action string, _ map[string]interface{}) error {
	m.calls = append(m.calls, struct {
		mode   EmergencyMode
		action string
	}{mode, action})
	if m.failAction == action {
		return errActionFailed
	}
	return nil
}

// errActionFailed is the sentinel error returned by a simulated failure.
var errActionFailed = &actionError{"simulated failure"}

type actionError struct{ msg string }

func (e *actionError) Error() string { return e.msg }
// --- Preservation Engine Tests ---

// SP-01: Safe mode activation.
func TestPreservation_SP01_SafeMode(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	err := pe.ActivateMode(ModeSafe, "quorum lost (3/6 offline)", "auto")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if pe.CurrentMode() != ModeSafe {
		t.Errorf("expected SAFE, got %s", pe.CurrentMode())
	}
	activation := pe.Activation()
	if activation == nil {
		t.Fatal("expected activation details")
	}
	if !activation.AutoExit {
		t.Error("safe mode should have auto-exit enabled")
	}
	// Should have executed safe mode actions.
	if len(log.calls) == 0 {
		t.Error("expected mode actions to be executed")
	}
	// First action should be disable_non_essential_services.
	if log.calls[0].action != "disable_non_essential_services" {
		t.Errorf("expected first action disable_non_essential_services, got %s", log.calls[0].action)
	}
}

// SP-02: Lockdown mode activation.
func TestPreservation_SP02_LockdownMode(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	err := pe.ActivateMode(ModeLockdown, "binary tampering detected", "auto")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if pe.CurrentMode() != ModeLockdown {
		t.Errorf("expected LOCKDOWN, got %s", pe.CurrentMode())
	}
	// Should have network isolation action.
	foundIsolate := false
	for _, c := range log.calls {
		if c.action == "isolate_from_network" {
			foundIsolate = true
		}
	}
	if !foundIsolate {
		t.Error("expected isolate_from_network in lockdown actions")
	}
}

// SP-03: Apoptosis mode activation (and irreversibility).
func TestPreservation_SP03_ApoptosisMode(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	err := pe.ActivateMode(ModeApoptosis, "rootkit detected", "architect:admin")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if pe.CurrentMode() != ModeApoptosis {
		t.Errorf("expected APOPTOSIS, got %s", pe.CurrentMode())
	}
	// Should have graceful_shutdown action.
	foundShutdown := false
	for _, c := range log.calls {
		if c.action == "graceful_shutdown" {
			foundShutdown = true
		}
	}
	if !foundShutdown {
		t.Error("expected graceful_shutdown in apoptosis actions")
	}
	// Cannot deactivate apoptosis.
	err = pe.DeactivateMode("architect:admin")
	if err == nil {
		t.Error("expected error deactivating apoptosis")
	}
}
// SP-04: Invalid transition (downgrade).
func TestPreservation_SP04_InvalidTransition(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	pe.ActivateMode(ModeLockdown, "test", "auto")
	// Can't downgrade from LOCKDOWN to SAFE.
	err := pe.ActivateMode(ModeSafe, "test downgrade", "auto")
	if err == nil {
		t.Error("expected error on downgrade from LOCKDOWN to SAFE")
	}
}

// SP-05: Escalation (SAFE → LOCKDOWN → APOPTOSIS).
func TestPreservation_SP05_Escalation(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	pe.ActivateMode(ModeSafe, "quorum lost", "auto")
	if pe.CurrentMode() != ModeSafe {
		t.Fatal("expected SAFE")
	}
	pe.ActivateMode(ModeLockdown, "compromise detected", "auto")
	if pe.CurrentMode() != ModeLockdown {
		t.Fatal("expected LOCKDOWN")
	}
	pe.ActivateMode(ModeApoptosis, "rootkit", "auto")
	if pe.CurrentMode() != ModeApoptosis {
		t.Fatal("expected APOPTOSIS")
	}
}

// SP-06: Safe mode auto-exit.
func TestPreservation_SP06_AutoExit(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	pe.ActivateMode(ModeSafe, "test", "auto")
	// Not yet time.
	if pe.ShouldAutoExit() {
		t.Error("should not auto-exit immediately")
	}
	// Fast-forward activation's auto_exit_at (reach into internals to avoid sleeping).
	pe.mu.Lock()
	pe.activation.AutoExitAt = time.Now().Add(-1 * time.Second)
	pe.mu.Unlock()
	if !pe.ShouldAutoExit() {
		t.Error("should auto-exit after timer expired")
	}
}

// SP-07: Manual deactivation of safe mode.
func TestPreservation_SP07_ManualDeactivate(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	pe.ActivateMode(ModeSafe, "test", "auto")
	err := pe.DeactivateMode("architect:admin")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if pe.CurrentMode() != ModeNone {
		t.Errorf("expected NONE, got %s", pe.CurrentMode())
	}
}

// SP-08: Lockdown deactivation (unlike apoptosis, lockdown is reversible).
func TestPreservation_SP08_LockdownDeactivate(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	pe.ActivateMode(ModeLockdown, "test", "auto")
	err := pe.DeactivateMode("architect:admin")
	if err != nil {
		t.Fatalf("lockdown deactivation should succeed: %v", err)
	}
}
// SP-09: History audit log.
func TestPreservation_SP09_AuditHistory(t *testing.T) {
	log := newModeActionLog()
	pe := NewPreservationEngine(log.execute)
	pe.ActivateMode(ModeSafe, "test", "auto")
	pe.DeactivateMode("admin")
	history := pe.History()
	if len(history) == 0 {
		t.Error("expected audit history entries")
	}
	// Last entry should be deactivation.
	last := history[len(history)-1]
	if last.Action != "deactivated" {
		t.Errorf("expected deactivated, got %s", last.Action)
	}
}

// SP-10: Action failure in non-apoptosis mode aborts.
func TestPreservation_SP10_ActionFailure(t *testing.T) {
	log := newModeActionLog()
	log.failAction = "disable_non_essential_services"
	pe := NewPreservationEngine(log.execute)
	err := pe.ActivateMode(ModeSafe, "test", "auto")
	if err == nil {
		t.Error("expected error when safe mode action fails")
	}
	// Mode should not have changed due to failure.
	if pe.CurrentMode() != ModeNone {
		t.Errorf("expected NONE after failed activation, got %s", pe.CurrentMode())
	}
}

// SP-10b: Action failure in apoptosis mode continues (best-effort termination).
func TestPreservation_SP10b_ApoptosisActionFailure(t *testing.T) {
	log := newModeActionLog()
	log.failAction = "graceful_shutdown"
	pe := NewPreservationEngine(log.execute)
	// Apoptosis should continue despite action failures.
	err := pe.ActivateMode(ModeApoptosis, "rootkit", "auto")
	if err != nil {
		t.Fatalf("apoptosis should not fail on action errors: %v", err)
	}
	if pe.CurrentMode() != ModeApoptosis {
		t.Errorf("expected APOPTOSIS, got %s", pe.CurrentMode())
	}
}

// Test ModeNone activation rejected.
func TestPreservation_ModeNoneRejected(t *testing.T) {
	pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
	err := pe.ActivateMode(ModeNone, "test", "auto")
	if err == nil {
		t.Error("expected error activating ModeNone")
	}
}

// Test deactivate when already NONE.
func TestPreservation_DeactivateNone(t *testing.T) {
	pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
	err := pe.DeactivateMode("admin")
	if err != nil {
		t.Errorf("deactivating NONE should be no-op: %v", err)
	}
}

// Test ShouldAutoExit when not in safe mode.
func TestPreservation_AutoExitNotSafe(t *testing.T) {
	pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
	if pe.ShouldAutoExit() {
		t.Error("should not auto-exit when mode is NONE")
	}
}
// --- Integrity Verifier Tests ---

// SP-04 (ТЗ): Binary integrity check — hash mismatch.
// A registered binary whose content changes must flip the report to COMPROMISED.
func TestIntegrity_BinaryMismatch(t *testing.T) {
	tmpDir := t.TempDir()
	binPath := filepath.Join(tmpDir, "test-binary")
	// Check the fixture write: previously the error was ignored, so a failed
	// write would surface as a confusing hash-mismatch assertion instead.
	if err := os.WriteFile(binPath, []byte("original content"), 0o644); err != nil {
		t.Fatalf("write fixture: %v", err)
	}
	// Calculate correct hash.
	h := sha256.Sum256([]byte("original content"))
	correctHash := hex.EncodeToString(h[:])
	iv := NewIntegrityVerifier([]byte("test-key"))
	iv.RegisterBinary(binPath, correctHash)
	// Verify (should pass).
	report := iv.VerifyAll()
	if report.Overall != IntegrityVerified {
		t.Errorf("expected VERIFIED, got %s", report.Overall)
	}
	// Tamper with the binary.
	if err := os.WriteFile(binPath, []byte("tampered content"), 0o644); err != nil {
		t.Fatalf("tamper fixture: %v", err)
	}
	// Verify (should fail).
	report = iv.VerifyAll()
	if report.Overall != IntegrityCompromised {
		t.Errorf("expected COMPROMISED, got %s", report.Overall)
	}
	bs := report.Binaries[binPath]
	if bs.Status != IntegrityCompromised {
		t.Errorf("expected binary COMPROMISED, got %s", bs.Status)
	}
}
// A registered binary that is missing on disk reports status UNKNOWN.
func TestIntegrity_BinaryNotFound(t *testing.T) {
	verifier := NewIntegrityVerifier([]byte("test-key"))
	verifier.RegisterBinary("/nonexistent/binary", "abc123")
	status := verifier.VerifyAll().Binaries["/nonexistent/binary"]
	if status.Status != IntegrityUnknown {
		t.Errorf("expected UNKNOWN for missing binary, got %s", status.Status)
	}
}
// Config HMAC computation: a readable registered config yields a valid,
// non-empty HMAC in the report.
func TestIntegrity_ConfigHMAC(t *testing.T) {
	tmpDir := t.TempDir()
	cfgPath := filepath.Join(tmpDir, "config.yaml")
	// Check the fixture write: previously the error was ignored.
	if err := os.WriteFile(cfgPath, []byte("server:\n port: 8080"), 0o644); err != nil {
		t.Fatalf("write fixture: %v", err)
	}
	iv := NewIntegrityVerifier([]byte("hmac-key"))
	iv.RegisterConfig(cfgPath)
	report := iv.VerifyAll()
	cs := report.Configs[cfgPath]
	if !cs.Valid {
		t.Errorf("expected valid config, got error: %s", cs.Error)
	}
	if cs.CurrentHMAC == "" {
		t.Error("expected non-empty HMAC")
	}
}
// A config that cannot be read is reported as invalid.
func TestIntegrity_ConfigUnreadable(t *testing.T) {
	verifier := NewIntegrityVerifier([]byte("key"))
	verifier.RegisterConfig("/nonexistent/config.yaml")
	status := verifier.VerifyAll().Configs["/nonexistent/config.yaml"]
	if status.Valid {
		t.Error("expected invalid for unreadable config")
	}
}
// Decision chain — a missing chain file is OK (no chain yet → valid).
func TestIntegrity_ChainNotExist(t *testing.T) {
	verifier := NewIntegrityVerifier([]byte("key"))
	verifier.SetChainPath("/nonexistent/decisions.log")
	report := verifier.VerifyAll()
	if report.Chain == nil {
		t.Fatal("expected chain status")
	}
	if !report.Chain.Valid {
		t.Error("nonexistent chain should be valid (no entries)")
	}
}
// Decision chain — an existing chain file verifies as valid.
func TestIntegrity_ChainExists(t *testing.T) {
	tmpDir := t.TempDir()
	chainPath := filepath.Join(tmpDir, "decisions.log")
	// Check the fixture write: previously the error was ignored.
	if err := os.WriteFile(chainPath, []byte("entry1\nentry2\n"), 0o644); err != nil {
		t.Fatalf("write fixture: %v", err)
	}
	iv := NewIntegrityVerifier([]byte("key"))
	iv.SetChainPath(chainPath)
	report := iv.VerifyAll()
	if report.Chain == nil {
		t.Fatal("expected chain status")
	}
	if !report.Chain.Valid {
		t.Error("expected valid chain")
	}
}
// LastReport is nil before the first VerifyAll and non-nil afterwards.
func TestIntegrity_LastReport(t *testing.T) {
	verifier := NewIntegrityVerifier([]byte("key"))
	if verifier.LastReport() != nil {
		t.Error("expected nil before first verify")
	}
	verifier.VerifyAll()
	if verifier.LastReport() == nil {
		t.Error("expected report after verify")
	}
}
// Pluggable integrity check in PreservationEngine.
func TestPreservation_IntegrityCheck(t *testing.T) {
pe := NewPreservationEngine(func(_ EmergencyMode, _ string, _ map[string]interface{}) error { return nil })
// Default: no integrity fn → VERIFIED.
report := pe.CheckIntegrity()
if report.Overall != IntegrityVerified {
t.Errorf("expected VERIFIED, got %s", report.Overall)
}
// Set custom checker.
pe.SetIntegrityCheck(func() IntegrityReport {
return IntegrityReport{Overall: IntegrityCompromised, Timestamp: time.Now()}
})
report = pe.CheckIntegrity()
if report.Overall != IntegrityCompromised {
t.Errorf("expected COMPROMISED from custom checker, got %s", report.Overall)
}
}

View file

@ -0,0 +1,398 @@
package resilience
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// PlaybookStatus tracks the state of a running playbook.
type PlaybookStatus string

const (
	// PlaybookPending: declared for API completeness; not assigned by Execute in this file.
	PlaybookPending PlaybookStatus = "PENDING"
	// PlaybookRunning: set when Execute publishes the execution record.
	PlaybookRunning PlaybookStatus = "RUNNING"
	// PlaybookSucceeded: all actions completed (or were continue-on-error).
	PlaybookSucceeded PlaybookStatus = "SUCCEEDED"
	// PlaybookFailed: declared for API completeness; not assigned by Execute in this file.
	PlaybookFailed PlaybookStatus = "FAILED"
	// PlaybookRolledBack: an action failed and the rollback actions were run.
	PlaybookRolledBack PlaybookStatus = "ROLLED_BACK"
)

// PlaybookStep is a single step in a recovery playbook.
type PlaybookStep struct {
	ID   string `json:"id"`
	Name string `json:"name"`
	// Type selects which executor backend handles the step.
	Type string `json:"type"` // shell, api, consensus, crypto, systemd, http, prometheus
	// Timeout bounds a single attempt; zero means no per-step timeout.
	Timeout time.Duration `json:"timeout"`
	// Retries is the total number of attempts; runStep treats values <= 0 as 1.
	Retries int                    `json:"retries"`
	Params  map[string]interface{} `json:"params,omitempty"`
	// OnError controls Execute's reaction when an action step fails.
	OnError   string `json:"on_error"`            // abort, continue, rollback
	Condition string `json:"condition,omitempty"` // prerequisite condition
}

// Playbook defines a complete recovery procedure.
type Playbook struct {
	ID              string `json:"id"`
	Name            string `json:"name"`
	Version         string `json:"version"`
	TriggerMetric   string `json:"trigger_metric"`
	TriggerSeverity string `json:"trigger_severity"`
	// DiagnosisChecks run first; Execute logs their failures but never aborts on them.
	DiagnosisChecks []PlaybookStep `json:"diagnosis_checks"`
	// Actions run second; a failure consults the step's OnError policy.
	Actions []PlaybookStep `json:"actions"`
	// RollbackActions run when an action fails with abort or rollback.
	RollbackActions []PlaybookStep `json:"rollback_actions"`
	// SuccessCriteria are informational expressions; they are not evaluated in this file.
	SuccessCriteria []string `json:"success_criteria"`
}

// PlaybookExecution tracks a single playbook run.
type PlaybookExecution struct {
	ID          string         `json:"id"`
	PlaybookID  string         `json:"playbook_id"`
	Component   string         `json:"component"`
	Status      PlaybookStatus `json:"status"`
	StartedAt   time.Time      `json:"started_at"`
	CompletedAt time.Time      `json:"completed_at,omitempty"`
	StepsRun    []StepResult   `json:"steps_run"`
	Error       string         `json:"error,omitempty"`
}

// StepResult records the execution of a single playbook step.
type StepResult struct {
	StepID   string `json:"step_id"`
	StepName string `json:"step_name"`
	Success  bool   `json:"success"`
	// Duration spans all retry attempts, not just the final one.
	Duration time.Duration `json:"duration"`
	Output   string        `json:"output,omitempty"`
	Error    string        `json:"error,omitempty"`
}

// PlaybookExecutorFunc runs a single playbook step. The context passed in
// already carries the step timeout when one is configured.
type PlaybookExecutorFunc func(ctx context.Context, step PlaybookStep, component string) (string, error)

// RecoveryPlaybookEngine manages and executes recovery playbooks.
type RecoveryPlaybookEngine struct {
	mu         sync.RWMutex
	playbooks  map[string]*Playbook // keyed by Playbook.ID
	executions []*PlaybookExecution // append-only history, oldest first
	execCount  int64                // monotonically increasing, used to build execution IDs
	executor   PlaybookExecutorFunc
	logger     *slog.Logger
}
// NewRecoveryPlaybookEngine creates a new playbook engine. The supplied
// executor is invoked for every step of every playbook run.
func NewRecoveryPlaybookEngine(executor PlaybookExecutorFunc) *RecoveryPlaybookEngine {
	engine := &RecoveryPlaybookEngine{
		playbooks:  map[string]*Playbook{},
		executions: []*PlaybookExecution{},
		executor:   executor,
		logger:     slog.Default().With("component", "sarl-recovery-playbooks"),
	}
	return engine
}
// RegisterPlaybook adds a playbook to the engine, replacing any existing
// playbook with the same ID.
func (rpe *RecoveryPlaybookEngine) RegisterPlaybook(pb Playbook) {
	rpe.mu.Lock()
	stored := pb
	rpe.playbooks[stored.ID] = &stored
	rpe.mu.Unlock()
	rpe.logger.Info("playbook registered", "id", pb.ID, "name", pb.Name)
}
// Execute runs a playbook for a given component. Returns the execution ID.
//
// The *PlaybookExecution record is published into rpe.executions before the
// steps run, so concurrent readers (GetExecution, RecentExecutions) can see
// it. Every mutation after publication therefore happens under rpe.mu —
// previously StepsRun/Status/Error/CompletedAt were written without the lock,
// racing with readers that deep-copy the record under RLock.
func (rpe *RecoveryPlaybookEngine) Execute(ctx context.Context, playbookID, component string) (string, error) {
	rpe.mu.Lock()
	pb, ok := rpe.playbooks[playbookID]
	if !ok {
		rpe.mu.Unlock()
		return "", fmt.Errorf("playbook %s not found", playbookID)
	}
	rpe.execCount++
	exec := &PlaybookExecution{
		ID:         fmt.Sprintf("exec-%d", rpe.execCount),
		PlaybookID: playbookID,
		Component:  component,
		Status:     PlaybookRunning,
		StartedAt:  time.Now(),
		StepsRun:   make([]StepResult, 0),
	}
	rpe.executions = append(rpe.executions, exec)
	rpe.mu.Unlock()

	// record appends a step result under the lock so readers never observe a
	// slice that is being appended to concurrently.
	record := func(r StepResult) {
		rpe.mu.Lock()
		exec.StepsRun = append(exec.StepsRun, r)
		rpe.mu.Unlock()
	}

	rpe.logger.Info("playbook execution started",
		"exec_id", exec.ID,
		"playbook", pb.Name,
		"component", component,
	)

	// Phase 1: Diagnosis checks. Failures are logged but never abort the run.
	for _, check := range pb.DiagnosisChecks {
		result := rpe.runStep(ctx, check, component)
		record(result)
		if !result.Success {
			rpe.logger.Warn("diagnosis check failed",
				"step", check.ID,
				"error", result.Error,
			)
		}
	}

	// Phase 2: Execute recovery actions. A failed step consults its OnError
	// policy: "continue" skips it, anything else stops the action loop.
	var execErr error
	for _, action := range pb.Actions {
		result := rpe.runStep(ctx, action, component)
		record(result)
		if result.Success {
			continue
		}
		switch action.OnError {
		case "continue":
			continue
		case "rollback":
			execErr = fmt.Errorf("step %s failed (rollback): %s", action.ID, result.Error)
		default: // "abort"
			execErr = fmt.Errorf("step %s failed: %s", action.ID, result.Error)
		}
		break
	}

	// Phase 3: Handle result. Both abort and rollback failures run the
	// rollback actions (the behavior the AR-02/AR-04 tests rely on).
	finalStatus := PlaybookSucceeded
	if execErr != nil {
		rpe.logger.Error("playbook failed, executing rollback",
			"exec_id", exec.ID,
			"error", execErr,
		)
		for _, rb := range pb.RollbackActions {
			record(rpe.runStep(ctx, rb, component))
		}
		finalStatus = PlaybookRolledBack
	} else {
		rpe.logger.Info("playbook succeeded",
			"exec_id", exec.ID,
			"component", component,
			"duration", time.Since(exec.StartedAt),
		)
	}

	rpe.mu.Lock()
	exec.Status = finalStatus
	if execErr != nil {
		exec.Error = execErr.Error()
	}
	exec.CompletedAt = time.Now()
	rpe.mu.Unlock()
	return exec.ID, execErr
}
// runStep executes a single step with timeout and retries. The recorded
// Duration spans all attempts; Retries <= 0 is treated as a single attempt.
func (rpe *RecoveryPlaybookEngine) runStep(ctx context.Context, step PlaybookStep, component string) StepResult {
	begin := time.Now()
	res := StepResult{
		StepID:   step.ID,
		StepName: step.Name,
	}
	attempts := step.Retries
	if attempts <= 0 {
		attempts = 1
	}
	var lastErr error
	for i := 0; i < attempts; i++ {
		output, err := rpe.runAttempt(ctx, step, component)
		if err == nil {
			res.Success = true
			res.Output = output
			res.Duration = time.Since(begin)
			return res
		}
		lastErr = err
		if i+1 < attempts {
			rpe.logger.Warn("step retry",
				"step", step.ID,
				"attempt", i+1,
				"error", err,
			)
		}
	}
	res.Success = false
	res.Error = lastErr.Error()
	res.Duration = time.Since(begin)
	return res
}

// runAttempt performs one executor invocation, applying the per-step timeout
// when configured.
func (rpe *RecoveryPlaybookEngine) runAttempt(ctx context.Context, step PlaybookStep, component string) (string, error) {
	if step.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, step.Timeout)
		defer cancel()
	}
	return rpe.executor(ctx, step, component)
}
// GetExecution returns a playbook execution by ID.
// Returns a deep copy to prevent data races with the execution goroutine.
func (rpe *RecoveryPlaybookEngine) GetExecution(id string) (*PlaybookExecution, bool) {
	rpe.mu.RLock()
	defer rpe.mu.RUnlock()
	for _, candidate := range rpe.executions {
		if candidate.ID != id {
			continue
		}
		snapshot := *candidate
		snapshot.StepsRun = make([]StepResult, len(candidate.StepsRun))
		copy(snapshot.StepsRun, candidate.StepsRun)
		return &snapshot, true
	}
	return nil, false
}
// RecentExecutions returns the last n executions, oldest first.
// Returns deep copies to prevent data races with the execution goroutine.
//
// Fix: the previous version passed n straight to make() as a capacity, so a
// negative n panicked at run time ("makeslice: cap out of range").
// Non-positive n now yields nil, consistent with the empty-history case.
func (rpe *RecoveryPlaybookEngine) RecentExecutions(n int) []PlaybookExecution {
	rpe.mu.RLock()
	defer rpe.mu.RUnlock()
	total := len(rpe.executions)
	if total == 0 || n <= 0 {
		return nil
	}
	start := total - n
	if start < 0 {
		start = 0
	}
	result := make([]PlaybookExecution, 0, total-start)
	for i := start; i < total; i++ {
		cp := *rpe.executions[i]
		cp.StepsRun = make([]StepResult, len(rpe.executions[i].StepsRun))
		copy(cp.StepsRun, rpe.executions[i].StepsRun)
		result = append(result, cp)
	}
	return result
}
// PlaybookCount returns the number of registered playbooks.
func (rpe *RecoveryPlaybookEngine) PlaybookCount() int {
	rpe.mu.RLock()
	n := len(rpe.playbooks)
	rpe.mu.RUnlock()
	return n
}
// --- Built-in playbooks per ТЗ §7.1 ---

// DefaultPlaybooks returns the 3 built-in recovery playbooks.
func DefaultPlaybooks() []Playbook {
	playbooks := make([]Playbook, 0, 3)
	playbooks = append(playbooks, ComponentResurrectionPlaybook())
	playbooks = append(playbooks, ConsensusRecoveryPlaybook())
	playbooks = append(playbooks, CryptoRotationPlaybook())
	return playbooks
}
// ComponentResurrectionPlaybook per ТЗ §7.1.1.
//
// Restores a component reported offline: diagnose, capture forensics,
// restart via systemd, then verify health. OnError semantics follow Execute:
// "continue" steps never stop the run, "abort" ends the action phase and
// triggers the rollback actions.
func ComponentResurrectionPlaybook() Playbook {
	return Playbook{
		ID:              "component-resurrection",
		Name:            "Component Resurrection",
		Version:         "1.0",
		TriggerMetric:   "component_offline",
		TriggerSeverity: "CRITICAL",
		// Read-only diagnosis; failures here are logged by Execute but never abort.
		DiagnosisChecks: []PlaybookStep{
			{ID: "diag-process", Name: "Check process exists", Type: "shell", Timeout: 5 * time.Second},
			{ID: "diag-crashes", Name: "Check recent crashes", Type: "shell", Timeout: 5 * time.Second},
			{ID: "diag-resources", Name: "Check resource exhaustion", Type: "prometheus", Timeout: 5 * time.Second},
			{ID: "diag-deps", Name: "Check dependency health", Type: "http", Timeout: 10 * time.Second},
		},
		// The restart and the health check are the only hard requirements ("abort");
		// forensics/cleanup/notification are best-effort ("continue").
		Actions: []PlaybookStep{
			{ID: "capture-forensics", Name: "Capture forensics", Type: "shell", Timeout: 30 * time.Second, OnError: "continue"},
			{ID: "clear-resources", Name: "Clear temp resources", Type: "shell", Timeout: 10 * time.Second, OnError: "continue"},
			{ID: "restart-component", Name: "Restart component", Type: "systemd", Timeout: 60 * time.Second, OnError: "abort"},
			{ID: "verify-health", Name: "Verify health", Type: "http", Timeout: 30 * time.Second, Retries: 3, OnError: "abort"},
			{ID: "verify-metrics", Name: "Verify metrics", Type: "prometheus", Timeout: 30 * time.Second, OnError: "continue"},
			{ID: "notify-success", Name: "Notify SOC", Type: "api", Timeout: 5 * time.Second, OnError: "continue"},
		},
		// Executed by Execute when an action fails with abort/rollback.
		RollbackActions: []PlaybookStep{
			{ID: "rb-safe-mode", Name: "Enter safe mode", Type: "api", Timeout: 10 * time.Second},
			{ID: "rb-notify", Name: "Notify architect", Type: "api", Timeout: 5 * time.Second},
		},
		// Informational; not evaluated by the engine in this file.
		SuccessCriteria: []string{
			"component_status == HEALTHY",
			"health_check_passed == true",
			"no_crashes_for_5min == true",
		},
	}
}
// ConsensusRecoveryPlaybook per ТЗ §7.1.2.
//
// Recovers a cluster from split-brain: pause writes, elect a leader, sync
// state, verify consistency, then resume. Every state-changing step is
// "abort" except the long-running sync, which is "rollback" — either way a
// failure runs the rollback actions (cluster stays readonly).
func ConsensusRecoveryPlaybook() Playbook {
	return Playbook{
		ID:              "consensus-recovery",
		Name:            "Distributed Consensus Recovery",
		Version:         "1.0",
		TriggerMetric:   "split_brain",
		TriggerSeverity: "CRITICAL",
		// Read-only diagnosis; failures here are logged but never abort.
		DiagnosisChecks: []PlaybookStep{
			{ID: "diag-peers", Name: "Check peer connectivity", Type: "api", Timeout: 10 * time.Second},
			{ID: "diag-sync", Name: "Check sync status", Type: "api", Timeout: 10 * time.Second},
			{ID: "diag-genome", Name: "Verify genome", Type: "api", Timeout: 5 * time.Second},
		},
		Actions: []PlaybookStep{
			{ID: "pause-writes", Name: "Pause all writes", Type: "api", Timeout: 10 * time.Second, OnError: "abort"},
			{ID: "elect-leader", Name: "Elect leader (Raft)", Type: "consensus", Timeout: 60 * time.Second, OnError: "abort"},
			{ID: "sync-state", Name: "Sync state from leader", Type: "api", Timeout: 300 * time.Second, OnError: "rollback"},
			{ID: "verify-consistency", Name: "Verify consistency", Type: "api", Timeout: 60 * time.Second, OnError: "abort"},
			{ID: "resume-writes", Name: "Resume writes", Type: "api", Timeout: 10 * time.Second, OnError: "abort"},
			{ID: "notify-cluster", Name: "Notify cluster", Type: "api", Timeout: 5 * time.Second, OnError: "continue"},
		},
		// On failure the cluster is deliberately left readonly.
		RollbackActions: []PlaybookStep{
			{ID: "rb-readonly", Name: "Maintain readonly", Type: "api", Timeout: 10 * time.Second},
			{ID: "rb-notify", Name: "Notify architect", Type: "api", Timeout: 5 * time.Second},
		},
		// Informational; not evaluated by the engine in this file.
		SuccessCriteria: []string{
			"leader_elected == true",
			"state_synced == true",
			"consistency_verified == true",
			"writes_resumed == true",
		},
	}
}
// CryptoRotationPlaybook per ТЗ §7.1.3.
//
// Rotates cryptographic material after a suspected key compromise: generate
// new keys, rotate mTLS certs, re-sign the decision chain, verify peers and
// revoke the old keys. A failed cert rotation triggers rollback (revert to
// the previous keys); re-signing and revocation are best-effort.
func CryptoRotationPlaybook() Playbook {
	return Playbook{
		ID:              "crypto-rotation",
		Name:            "Cryptographic Key Rotation",
		Version:         "1.0",
		TriggerMetric:   "key_compromise",
		TriggerSeverity: "HIGH",
		// Read-only diagnosis; failures here are logged but never abort.
		DiagnosisChecks: []PlaybookStep{
			{ID: "diag-key-age", Name: "Check key age", Type: "crypto", Timeout: 5 * time.Second},
			{ID: "diag-usage", Name: "Check key usage anomaly", Type: "prometheus", Timeout: 5 * time.Second},
			{ID: "diag-tpm", Name: "Check TPM health", Type: "shell", Timeout: 5 * time.Second},
		},
		Actions: []PlaybookStep{
			{ID: "gen-keys", Name: "Generate new keys", Type: "crypto", Timeout: 30 * time.Second, OnError: "abort",
				Params: map[string]interface{}{"algorithm": "ECDSA-P256"},
			},
			{ID: "rotate-certs", Name: "Rotate mTLS certs", Type: "crypto", Timeout: 120 * time.Second, OnError: "rollback"},
			{ID: "resign-chain", Name: "Re-sign decision chain", Type: "crypto", Timeout: 300 * time.Second, OnError: "continue"},
			{ID: "verify-peers", Name: "Verify peer certs", Type: "api", Timeout: 60 * time.Second, OnError: "abort"},
			{ID: "revoke-old", Name: "Revoke old keys", Type: "crypto", Timeout: 30 * time.Second, OnError: "continue"},
			{ID: "notify-soc", Name: "Notify SOC", Type: "api", Timeout: 5 * time.Second, OnError: "continue"},
		},
		RollbackActions: []PlaybookStep{
			{ID: "rb-revert-keys", Name: "Revert to previous keys", Type: "crypto", Timeout: 30 * time.Second},
			{ID: "rb-notify", Name: "Notify architect", Type: "api", Timeout: 5 * time.Second},
		},
		// Informational; not evaluated by the engine in this file.
		SuccessCriteria: []string{
			"new_keys_generated == true",
			"certs_distributed == true",
			"peers_verified == true",
			"old_keys_revoked == true",
		},
	}
}

View file

@ -0,0 +1,318 @@
package resilience
import (
"context"
"fmt"
"testing"
"time"
)
// --- Mock playbook executor ---

// mockPlaybookExecutor fails exactly the steps listed in failSteps and counts
// every invocation.
type mockPlaybookExecutor struct {
	failSteps map[string]bool
	callCount int
}

func newMockPlaybookExecutor() *mockPlaybookExecutor {
	return &mockPlaybookExecutor{failSteps: map[string]bool{}}
}

func (m *mockPlaybookExecutor) execute(_ context.Context, step PlaybookStep, _ string) (string, error) {
	m.callCount++
	if !m.failSteps[step.ID] {
		return fmt.Sprintf("step %s completed", step.ID), nil
	}
	return "", fmt.Errorf("step %s failed", step.ID)
}
// --- Recovery Playbook Tests ---

// AR-01: Component resurrection (success).
func TestPlaybook_AR01_ResurrectionSuccess(t *testing.T) {
	executor := newMockPlaybookExecutor()
	engine := NewRecoveryPlaybookEngine(executor.execute)
	engine.RegisterPlaybook(ComponentResurrectionPlaybook())
	execID, err := engine.Execute(context.Background(), "component-resurrection", "soc-ingest")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	record, ok := engine.GetExecution(execID)
	if !ok {
		t.Fatal("execution not found")
	}
	if record.Status != PlaybookSucceeded {
		t.Errorf("expected SUCCEEDED, got %s", record.Status)
	}
	if len(record.StepsRun) == 0 {
		t.Error("expected steps to be recorded")
	}
}
// AR-02: Component resurrection (failure → rollback).
func TestPlaybook_AR02_ResurrectionFailure(t *testing.T) {
	executor := newMockPlaybookExecutor()
	executor.failSteps["restart-component"] = true
	engine := NewRecoveryPlaybookEngine(executor.execute)
	engine.RegisterPlaybook(ComponentResurrectionPlaybook())
	if _, err := engine.Execute(context.Background(), "component-resurrection", "soc-ingest"); err == nil {
		t.Fatal("expected error")
	}
	recent := engine.RecentExecutions(10)
	if len(recent) == 0 {
		t.Fatal("expected execution")
	}
	if recent[0].Status != PlaybookRolledBack {
		t.Errorf("expected ROLLED_BACK, got %s", recent[0].Status)
	}
}
// AR-03: Consensus recovery (success).
func TestPlaybook_AR03_ConsensusSuccess(t *testing.T) {
	executor := newMockPlaybookExecutor()
	engine := NewRecoveryPlaybookEngine(executor.execute)
	engine.RegisterPlaybook(ConsensusRecoveryPlaybook())
	if _, err := engine.Execute(context.Background(), "consensus-recovery", "cluster"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}
// AR-04: Consensus recovery (failure → readonly maintained).
func TestPlaybook_AR04_ConsensusFailure(t *testing.T) {
	mock := newMockPlaybookExecutor()
	mock.failSteps["elect-leader"] = true
	rpe := NewRecoveryPlaybookEngine(mock.execute)
	rpe.RegisterPlaybook(ConsensusRecoveryPlaybook())
	_, err := rpe.Execute(context.Background(), "consensus-recovery", "cluster")
	if err == nil {
		t.Fatal("expected error")
	}
	execs := rpe.RecentExecutions(10)
	// Guard before indexing (matches AR-02); without it an empty result
	// panics and masks the real failure.
	if len(execs) == 0 {
		t.Fatal("expected execution")
	}
	if execs[0].Status != PlaybookRolledBack {
		t.Errorf("expected ROLLED_BACK, got %s", execs[0].Status)
	}
}
// AR-05: Crypto key rotation (success).
func TestPlaybook_AR05_CryptoSuccess(t *testing.T) {
	executor := newMockPlaybookExecutor()
	engine := NewRecoveryPlaybookEngine(executor.execute)
	engine.RegisterPlaybook(CryptoRotationPlaybook())
	if _, err := engine.Execute(context.Background(), "crypto-rotation", "system"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}
// AR-06: Crypto rotation (emergency — cert rotation fails → rollback).
func TestPlaybook_AR06_CryptoRollback(t *testing.T) {
	mock := newMockPlaybookExecutor()
	mock.failSteps["rotate-certs"] = true
	rpe := NewRecoveryPlaybookEngine(mock.execute)
	rpe.RegisterPlaybook(CryptoRotationPlaybook())
	_, err := rpe.Execute(context.Background(), "crypto-rotation", "system")
	if err == nil {
		t.Fatal("expected error on cert rotation failure")
	}
	execs := rpe.RecentExecutions(10)
	// Guard before indexing (matches AR-02); without it an empty result
	// panics and masks the real failure.
	if len(execs) == 0 {
		t.Fatal("expected execution")
	}
	// Should have run rollback (revert keys).
	found := false
	for _, s := range execs[0].StepsRun {
		if s.StepID == "rb-revert-keys" {
			found = true
		}
	}
	if !found {
		t.Error("expected rollback step rb-revert-keys")
	}
}
// AR-07: every recorded step carries its ID and name.
func TestPlaybook_AR07_ForensicCapture(t *testing.T) {
	executor := newMockPlaybookExecutor()
	engine := NewRecoveryPlaybookEngine(executor.execute)
	engine.RegisterPlaybook(ComponentResurrectionPlaybook())
	execID, _ := engine.Execute(context.Background(), "component-resurrection", "comp")
	record, _ := engine.GetExecution(execID)
	for _, result := range record.StepsRun {
		if result.StepID == "" {
			t.Error("step missing ID")
		}
		if result.StepName == "" {
			t.Errorf("step %s has empty name", result.StepID)
		}
	}
}
// AR-08: Rollback execution on action failure.
func TestPlaybook_AR08_RollbackExecution(t *testing.T) {
	mock := newMockPlaybookExecutor()
	mock.failSteps["sync-state"] = true // Sync fails → rollback trigger.
	rpe := NewRecoveryPlaybookEngine(mock.execute)
	rpe.RegisterPlaybook(ConsensusRecoveryPlaybook())
	// Previously the Execute error was silently discarded; assert it.
	if _, err := rpe.Execute(context.Background(), "consensus-recovery", "cluster"); err == nil {
		t.Fatal("expected error from failed sync-state")
	}
	execs := rpe.RecentExecutions(10)
	// Guard before indexing (matches AR-02).
	if len(execs) == 0 {
		t.Fatal("expected execution")
	}
	if execs[0].Status != PlaybookRolledBack {
		t.Errorf("expected ROLLED_BACK, got %s", execs[0].Status)
	}
}
// AR-09: Step retries.
//
// Fix: the previous version gated the failure on a global call counter, but
// "verify-health" is only reached after 7 earlier steps (4 diagnosis + 3
// actions), so `callCount <= 2` was never true and the retry path was never
// exercised. Count attempts of the target step itself instead, and assert
// the expected attempt count (Retries: 3 → two failures + one success).
func TestPlaybook_AR09_StepRetries(t *testing.T) {
	healthAttempts := 0
	executor := func(_ context.Context, step PlaybookStep, _ string) (string, error) {
		if step.ID == "verify-health" {
			healthAttempts++
			if healthAttempts <= 2 {
				return "", fmt.Errorf("not healthy yet")
			}
		}
		return "ok", nil
	}
	rpe := NewRecoveryPlaybookEngine(executor)
	rpe.RegisterPlaybook(ComponentResurrectionPlaybook())
	_, err := rpe.Execute(context.Background(), "component-resurrection", "comp")
	if err != nil {
		t.Fatalf("expected success after retries: %v", err)
	}
	if healthAttempts != 3 {
		t.Errorf("expected 3 verify-health attempts, got %d", healthAttempts)
	}
}
// AR-10: executing an unregistered playbook ID returns an error.
func TestPlaybook_AR10_NotFound(t *testing.T) {
	engine := NewRecoveryPlaybookEngine(nil)
	if _, err := engine.Execute(context.Background(), "nonexistent", "comp"); err == nil {
		t.Fatal("expected error for nonexistent playbook")
	}
}
// AR-11: an execution record carries both start and completion timestamps.
func TestPlaybook_AR11_AuditTimestamps(t *testing.T) {
	executor := newMockPlaybookExecutor()
	engine := NewRecoveryPlaybookEngine(executor.execute)
	engine.RegisterPlaybook(ComponentResurrectionPlaybook())
	execID, _ := engine.Execute(context.Background(), "component-resurrection", "comp")
	record, _ := engine.GetExecution(execID)
	if record.StartedAt.IsZero() {
		t.Error("missing started_at")
	}
	if record.CompletedAt.IsZero() {
		t.Error("missing completed_at")
	}
}
// AR-12: OnError=continue skips non-critical failures.
func TestPlaybook_AR12_ContinueOnError(t *testing.T) {
	executor := newMockPlaybookExecutor()
	// Both of these steps are declared OnError=continue in the playbook.
	executor.failSteps["capture-forensics"] = true
	executor.failSteps["notify-success"] = true
	engine := NewRecoveryPlaybookEngine(executor.execute)
	engine.RegisterPlaybook(ComponentResurrectionPlaybook())
	if _, err := engine.Execute(context.Background(), "component-resurrection", "comp"); err != nil {
		t.Fatalf("expected success despite continue-on-error steps: %v", err)
	}
}
// AR-13: an already-canceled context must not hang the engine.
func TestPlaybook_AR13_ContextCancel(t *testing.T) {
	executor := func(ctx context.Context, _ PlaybookStep, _ string) (string, error) {
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		case <-time.After(10 * time.Millisecond):
			return "ok", nil
		}
	}
	engine := NewRecoveryPlaybookEngine(executor)
	engine.RegisterPlaybook(ComponentResurrectionPlaybook())
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Cancel immediately.
	// May or may not error depending on timing, but should not hang.
	_, _ = engine.Execute(ctx, "component-resurrection", "comp")
}
// AR-14: DefaultPlaybooks returns 3 distinct, non-empty playbooks.
func TestPlaybook_AR14_DefaultPlaybooks(t *testing.T) {
	playbooks := DefaultPlaybooks()
	if len(playbooks) != 3 {
		t.Errorf("expected 3 playbooks, got %d", len(playbooks))
	}
	seen := make(map[string]bool, len(playbooks))
	for _, pb := range playbooks {
		if seen[pb.ID] {
			t.Errorf("duplicate playbook ID: %s", pb.ID)
		}
		seen[pb.ID] = true
		if len(pb.Actions) == 0 {
			t.Errorf("playbook %s has no actions", pb.ID)
		}
		if len(pb.SuccessCriteria) == 0 {
			t.Errorf("playbook %s has no success criteria", pb.ID)
		}
	}
}
// AR-15: PlaybookCount and RecentExecutions.
func TestPlaybook_AR15_CountsAndRecent(t *testing.T) {
	executor := newMockPlaybookExecutor()
	engine := NewRecoveryPlaybookEngine(executor.execute)
	if engine.PlaybookCount() != 0 {
		t.Error("expected 0")
	}
	for _, pb := range DefaultPlaybooks() {
		engine.RegisterPlaybook(pb)
	}
	if got := engine.PlaybookCount(); got != 3 {
		t.Errorf("expected 3, got %d", got)
	}
	// Run two playbooks; the most recent must be returned first-capped.
	engine.Execute(context.Background(), "component-resurrection", "comp1")
	engine.Execute(context.Background(), "crypto-rotation", "comp2")
	recent := engine.RecentExecutions(1)
	if len(recent) != 1 {
		t.Errorf("expected 1 recent, got %d", len(recent))
	}
	if recent[0].PlaybookID != "crypto-rotation" {
		t.Errorf("expected crypto-rotation, got %s", recent[0].PlaybookID)
	}
	all := engine.RecentExecutions(100)
	if len(all) != 2 {
		t.Errorf("expected 2 total, got %d", len(all))
	}
}

View file

@ -0,0 +1,278 @@
package shadow_ai
import (
"fmt"
"log/slog"
"sync"
"time"
)
// --- Tiered Approval Workflow ---
// Implements §6 of the ТЗ: data classification → approval tier → SLA tracking.

// ApprovalStatus tracks the state of an approval request.
type ApprovalStatus string

const (
	ApprovalPending      ApprovalStatus = "pending"       // awaiting a human decision
	ApprovalApproved     ApprovalStatus = "approved"      // granted via Approve by a named approver
	ApprovalDenied       ApprovalStatus = "denied"        // rejected via Deny with a reason
	ApprovalExpired      ApprovalStatus = "expired"       // pending past its SLA deadline (see ExpireOverdue)
	ApprovalAutoApproved ApprovalStatus = "auto_approved" // granted by the system for auto-approve tiers
)
// DefaultApprovalTiers defines the approval requirements per data classification.
//
// Ordering matters: findTier scans this slice linearly and falls back to the
// LAST entry for unknown classifications, so the most restrictive tier
// (Tier 4) must stay last. An SLA of 0 means no auto-expiry.
func DefaultApprovalTiers() []ApprovalTier {
	return []ApprovalTier{
		{
			Name:           "Tier 1: Public Data",
			DataClass:      DataPublic,
			ApprovalNeeded: nil, // Auto-approve
			SLA:            0,
			AutoApprove:    true,
		},
		{
			Name:           "Tier 2: Internal Data",
			DataClass:      DataInternal,
			ApprovalNeeded: []string{"manager"},
			SLA:            4 * time.Hour,
			AutoApprove:    false,
		},
		{
			Name:           "Tier 3: Confidential Data",
			DataClass:      DataConfidential,
			ApprovalNeeded: []string{"manager", "soc"},
			SLA:            24 * time.Hour,
			AutoApprove:    false,
		},
		{
			Name:           "Tier 4: Critical Data",
			DataClass:      DataCritical,
			ApprovalNeeded: []string{"ciso"},
			SLA:            0, // Manual only, no auto-expire
			AutoApprove:    false,
		},
	}
}
// ApprovalEngine manages the tiered approval workflow.
// All exported methods are safe for concurrent use via mu.
type ApprovalEngine struct {
	mu       sync.RWMutex
	tiers    []ApprovalTier              // classification → policy mapping (see DefaultApprovalTiers)
	requests map[string]*ApprovalRequest // keyed by request ID
	logger   *slog.Logger
}
// NewApprovalEngine creates an engine configured with the default tiers.
func NewApprovalEngine() *ApprovalEngine {
	engine := &ApprovalEngine{
		tiers:    DefaultApprovalTiers(),
		requests: map[string]*ApprovalRequest{},
		logger:   slog.Default().With("component", "shadow-ai-approvals"),
	}
	return engine
}
// SubmitRequest creates a new approval request based on data classification.
// Auto-approve tiers are resolved immediately; every other tier is recorded
// as pending with an SLA-derived expiry (when the tier has an SLA).
func (ae *ApprovalEngine) SubmitRequest(userID, docID string, dataClass DataClassification) *ApprovalRequest {
	ae.mu.Lock()
	defer ae.mu.Unlock()
	tier := ae.findTier(dataClass)
	req := &ApprovalRequest{
		ID:        genApprovalID(),
		DocID:     docID,
		UserID:    userID,
		Tier:      tier.Name,
		DataClass: dataClass,
		Status:    string(ApprovalPending),
		CreatedAt: time.Now(),
	}
	if tier.SLA > 0 {
		req.ExpiresAt = req.CreatedAt.Add(tier.SLA)
	}
	switch {
	case tier.AutoApprove:
		req.Status = string(ApprovalAutoApproved)
		req.ApprovedBy = "system"
		req.ResolvedAt = time.Now()
		ae.logger.Info("auto-approved",
			"request_id", req.ID,
			"user", userID,
			"data_class", dataClass,
		)
	default:
		ae.logger.Info("approval required",
			"request_id", req.ID,
			"user", userID,
			"data_class", dataClass,
			"tier", tier.Name,
			"approvers", tier.ApprovalNeeded,
		)
	}
	ae.requests[req.ID] = req
	return req
}
// Approve marks a pending request as approved by the given identity.
// It is an error if the request is unknown or no longer pending.
func (ae *ApprovalEngine) Approve(requestID, approvedBy string) error {
	ae.mu.Lock()
	defer ae.mu.Unlock()
	req, found := ae.requests[requestID]
	if !found {
		return fmt.Errorf("request %s not found", requestID)
	}
	if req.Status != string(ApprovalPending) {
		return fmt.Errorf("request %s is not pending (status: %s)", requestID, req.Status)
	}
	req.Status = string(ApprovalApproved)
	req.ApprovedBy = approvedBy
	req.ResolvedAt = time.Now()
	ae.logger.Info("approved",
		"request_id", requestID,
		"approved_by", approvedBy,
	)
	return nil
}
// Deny marks a pending request as denied, recording who denied it and why.
// It is an error if the request is unknown or no longer pending.
func (ae *ApprovalEngine) Deny(requestID, deniedBy, reason string) error {
	ae.mu.Lock()
	defer ae.mu.Unlock()
	req, found := ae.requests[requestID]
	if !found {
		return fmt.Errorf("request %s not found", requestID)
	}
	if req.Status != string(ApprovalPending) {
		return fmt.Errorf("request %s is not pending (status: %s)", requestID, req.Status)
	}
	req.Status = string(ApprovalDenied)
	req.DeniedBy = deniedBy
	req.Reason = reason
	req.ResolvedAt = time.Now()
	ae.logger.Info("denied",
		"request_id", requestID,
		"denied_by", deniedBy,
		"reason", reason,
	)
	return nil
}
// GetRequest returns a copy of the approval request with the given ID.
func (ae *ApprovalEngine) GetRequest(requestID string) (*ApprovalRequest, bool) {
	ae.mu.RLock()
	defer ae.mu.RUnlock()
	if req, found := ae.requests[requestID]; found {
		snapshot := *req
		return &snapshot, true
	}
	return nil, false
}
// PendingRequests returns copies of all requests still awaiting a decision.
func (ae *ApprovalEngine) PendingRequests() []ApprovalRequest {
	ae.mu.RLock()
	defer ae.mu.RUnlock()
	var pending []ApprovalRequest
	for _, req := range ae.requests {
		if req.Status != string(ApprovalPending) {
			continue
		}
		pending = append(pending, *req)
	}
	return pending
}
// ExpireOverdue marks overdue pending requests as expired.
// Requests without an expiry (zero ExpiresAt) are never expired.
// Returns the number of requests transitioned.
func (ae *ApprovalEngine) ExpireOverdue() int {
	ae.mu.Lock()
	defer ae.mu.Unlock()
	now := time.Now()
	count := 0
	for _, req := range ae.requests {
		if req.Status != string(ApprovalPending) {
			continue
		}
		if req.ExpiresAt.IsZero() || !now.After(req.ExpiresAt) {
			continue
		}
		req.Status = string(ApprovalExpired)
		req.ResolvedAt = now
		count++
		ae.logger.Warn("request expired",
			"request_id", req.ID,
			"user", req.UserID,
			"expired_at", req.ExpiresAt,
		)
	}
	return count
}
// Stats returns approval workflow statistics: a total plus per-status counts.
func (ae *ApprovalEngine) Stats() map[string]int {
	ae.mu.RLock()
	defer ae.mu.RUnlock()
	// Status value → stats key. Unknown statuses are not counted per-bucket
	// but still contribute to "total".
	buckets := map[ApprovalStatus]string{
		ApprovalPending:      "pending",
		ApprovalApproved:     "approved",
		ApprovalDenied:       "denied",
		ApprovalExpired:      "expired",
		ApprovalAutoApproved: "auto_approved",
	}
	stats := map[string]int{"total": len(ae.requests)}
	for _, key := range buckets {
		stats[key] = 0
	}
	for _, req := range ae.requests {
		if key, known := buckets[ApprovalStatus(req.Status)]; known {
			stats[key]++
		}
	}
	return stats
}
// Tiers returns a copy of the approval tier configuration.
//
// Fix: the previous version read ae.tiers without taking the lock that every
// other method holds, and returned the internal slice directly — a caller
// could mutate the engine's tier policy in place. The copy is shallow: each
// tier's ApprovalNeeded slice is still shared, which is noted here because
// it is not deep-copied.
func (ae *ApprovalEngine) Tiers() []ApprovalTier {
	ae.mu.RLock()
	defer ae.mu.RUnlock()
	out := make([]ApprovalTier, len(ae.tiers))
	copy(out, ae.tiers)
	return out
}
// findTier returns the tier matching dataClass. Unknown classifications fall
// back to the last tier, which by convention is the most restrictive one.
func (ae *ApprovalEngine) findTier(dataClass DataClassification) ApprovalTier {
	for i := range ae.tiers {
		if ae.tiers[i].DataClass == dataClass {
			return ae.tiers[i]
		}
	}
	return ae.tiers[len(ae.tiers)-1]
}
// genApprovalID returns a process-unique approval request ID combining the
// current wall-clock millisecond with a monotonically increasing counter.
var (
	approvalCounter   uint64
	approvalCounterMu sync.Mutex
)

func genApprovalID() string {
	approvalCounterMu.Lock()
	defer approvalCounterMu.Unlock()
	approvalCounter++
	return fmt.Sprintf("apr-%d-%d", time.Now().UnixMilli(), approvalCounter)
}

View file

@ -0,0 +1,116 @@
package shadow_ai
import (
"time"
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
)
// ShadowAICorrelationRules returns SOC correlation rules specific to Shadow AI
// detection. These integrate into the existing SOC correlation engine.
//
// Each rule fires when MinEvents matching events of the listed categories
// occur within TimeWindow, and carries a kill-chain phase plus MITRE ATT&CK
// technique mapping for downstream triage. Rule IDs use the "SAI-CR-NNN"
// namespace.
func ShadowAICorrelationRules() []domsoc.SOCCorrelationRule {
	return []domsoc.SOCCorrelationRule{
		// Fan-out across multiple AI providers in a short window.
		{
			ID:                 "SAI-CR-001",
			Name:               "Multi-Service Shadow AI",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          3,
			TimeWindow:         10 * time.Minute,
			Severity:           domsoc.SeverityHigh,
			KillChainPhase:     "Reconnaissance",
			MITREMapping:       []string{"T1595"},
			Description:        "User accessing 3+ distinct AI services within 10 minutes. Indicates active AI tool exploration or data shopping across providers.",
		},
		// AI usage co-occurring with an exfiltration signal.
		{
			ID:                 "SAI-CR-002",
			Name:               "Shadow AI + Data Exfiltration",
			RequiredCategories: []string{"shadow_ai_usage", "exfiltration"},
			MinEvents:          2,
			TimeWindow:         15 * time.Minute,
			Severity:           domsoc.SeverityCritical,
			KillChainPhase:     "Exfiltration",
			MITREMapping:       []string{"T1041", "T1567"},
			Description:        "Shadow AI usage followed by data exfiltration attempt. Possible corporate data leakage via unauthorized AI services.",
		},
		// High event volume from one source.
		{
			ID:                 "SAI-CR-003",
			Name:               "Shadow AI Volume Spike",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          10,
			TimeWindow:         1 * time.Hour,
			Severity:           domsoc.SeverityHigh,
			KillChainPhase:     "Actions on Objectives",
			MITREMapping:       []string{"T1048"},
			Description:        "10+ shadow AI events from same source within 1 hour. Indicates bulk data transfer to external AI service.",
		},
		// Off-hours activity (clustering is done by the engine, not here).
		{
			ID:                 "SAI-CR-004",
			Name:               "Shadow AI After Hours",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          2,
			TimeWindow:         30 * time.Minute,
			Severity:           domsoc.SeverityMedium,
			KillChainPhase:     "Persistence",
			MITREMapping:       []string{"T1053"},
			Description:        "Shadow AI usage outside business hours (detected via timestamp clustering). May indicate automated scripts or insider threat.",
		},
		// Enforcement infrastructure going dark — possible attack on the detector itself.
		{
			ID:                 "SAI-CR-005",
			Name:               "Integration Failure Chain",
			RequiredCategories: []string{"integration_health"},
			MinEvents:          3,
			TimeWindow:         5 * time.Minute,
			Severity:           domsoc.SeverityCritical,
			KillChainPhase:     "Defense Evasion",
			MITREMapping:       []string{"T1562"},
			Description:        "3+ integration health failures in 5 minutes. Possible attack on enforcement infrastructure to blind Shadow AI detection.",
		},
		// AI usage plus a PII-leak signal: regulatory exposure.
		{
			ID:                 "SAI-CR-006",
			Name:               "Shadow AI + PII Leak",
			RequiredCategories: []string{"shadow_ai_usage", "pii_leak"},
			MinEvents:          2,
			TimeWindow:         10 * time.Minute,
			Severity:           domsoc.SeverityCritical,
			KillChainPhase:     "Exfiltration",
			MITREMapping:       []string{"T1567.002"},
			Description:        "Shadow AI usage combined with PII leak detection. GDPR/regulatory violation in progress — immediate response required.",
		},
		// Ordered sequence (SequenceCategories): usage first, then an evasion technique.
		{
			ID:                 "SAI-CR-007",
			Name:               "Shadow AI Evasion Attempt",
			SequenceCategories: []string{"shadow_ai_usage", "evasion"},
			MinEvents:          2,
			TimeWindow:         10 * time.Minute,
			Severity:           domsoc.SeverityHigh,
			KillChainPhase:     "Defense Evasion",
			MITREMapping:       []string{"T1090", "T1573"},
			Description:        "Shadow AI usage followed by evasion technique (VPN, proxy chaining, encoding). User attempting to bypass detection.",
		},
		// CrossSensor: same category observed from many distinct sensors.
		{
			ID:                 "SAI-CR-008",
			Name:               "Cross-Department AI Usage",
			RequiredCategories: []string{"shadow_ai_usage"},
			MinEvents:          5,
			TimeWindow:         30 * time.Minute,
			Severity:           domsoc.SeverityMedium,
			CrossSensor:        true,
			KillChainPhase:     "Lateral Movement",
			MITREMapping:       []string{"T1021"},
			Description:        "Shadow AI events from 5+ distinct network segments/sensors within 30 minutes. Indicates coordinated policy circumvention or compromised credentials used across departments.",
		},
		// Severity trend: escalating shadow AI event severity
		{
			ID:             "SAI-CR-009",
			Name:           "Shadow AI Escalation",
			SeverityTrend:  "ascending",
			TrendCategory:  "shadow_ai_usage",
			MinEvents:      3,
			TimeWindow:     30 * time.Minute,
			Severity:       domsoc.SeverityCritical,
			KillChainPhase: "Exploitation",
			MITREMapping:   []string{"T1059"},
			Description:    "Ascending severity pattern in Shadow AI events: user escalating from casual browsing to bulk data uploads. Crescendo data theft in progress.",
		},
	}
}

View file

@ -0,0 +1,503 @@
package shadow_ai
import (
"log/slog"
"math"
"regexp"
"sort"
"strings"
"sync"
"time"
)
// --- AI Signature Database ---

// AISignatureDB contains known AI service signatures for detection.
// All lookups take the read lock, so matching is safe to run concurrently
// with AddService, which extends the database under the write lock.
type AISignatureDB struct {
	mu             sync.RWMutex
	services       []AIServiceInfo  // known services (seeded by defaultAIServices)
	domainPatterns []*domainPattern // one compiled regex per (service, domain) pair
	apiKeyPatterns []*APIKeyPattern // regexes for provider API key formats
	httpSignatures []string         // lower-cased "header: value" fragments
}

// domainPattern pairs a compiled domain regex with the service it identifies.
type domainPattern struct {
	original string         // the configured domain, possibly with a "*" wildcard
	regex    *regexp.Regexp // compiled, anchored, case-insensitive form
	service  string         // owning service name
}

// APIKeyPattern defines a regex pattern for detecting AI API keys.
type APIKeyPattern struct {
	Name    string         `json:"name"`
	Pattern *regexp.Regexp `json:"-"`
	// Entropy is a minimum-entropy hint; ScanForAPIKeys does not enforce it —
	// TODO confirm where (if anywhere) it is applied.
	Entropy float64 `json:"min_entropy"`
}
// NewAISignatureDB creates a signature database pre-loaded with known AI services.
func NewAISignatureDB() *AISignatureDB {
	sigDB := new(AISignatureDB)
	sigDB.loadDefaults()
	return sigDB
}
// loadDefaults populates the database with the built-in AI services, one
// compiled regex per (service, domain) pair, the default API key patterns,
// and the lower-cased HTTP header fragments used for signature matching.
func (db *AISignatureDB) loadDefaults() {
	db.services = defaultAIServices()

	for _, svc := range db.services {
		for _, dom := range svc.Domains {
			db.domainPatterns = append(db.domainPatterns, &domainPattern{
				original: dom,
				regex:    domainToRegex(dom),
				service:  svc.Name,
			})
		}
	}

	db.apiKeyPatterns = defaultAPIKeyPatterns()

	// Compared case-insensitively against "key: value" header lines.
	db.httpSignatures = []string{
		"authorization: bearer sk-", // OpenAI
		"authorization: bearer ant-", // Anthropic
		"x-api-key: sk-ant-", // Anthropic v2
		"x-goog-api-key:", // Google AI
		"authorization: bearer gsk_", // Groq
		"authorization: bearer hf_", // HuggingFace
	}
}
// domainToRegex converts a wildcard domain (e.g., "*.openai.com") into a
// case-insensitive, fully-anchored regular expression. The "*" wildcard
// matches exactly one DNS label (no dots), so "*.openai.com" matches
// "api.openai.com" but not "openai.com" or "a.b.openai.com".
func domainToRegex(domain string) *regexp.Regexp {
	quoted := strings.ReplaceAll(regexp.QuoteMeta(domain), `\*`, `[a-zA-Z0-9\-]+`)
	return regexp.MustCompile("(?i)^" + quoted + "$")
}
// MatchDomain checks whether a domain matches any known AI service.
// The input is trimmed and lower-cased before matching. Returns the
// service name, or the empty string when nothing matches.
func (db *AISignatureDB) MatchDomain(domain string) string {
	normalized := strings.ToLower(strings.TrimSpace(domain))

	db.mu.RLock()
	defer db.mu.RUnlock()
	for _, dp := range db.domainPatterns {
		if dp.regex.MatchString(normalized) {
			return dp.service
		}
	}
	return ""
}
// MatchHTTPHeaders checks whether any HTTP header carries a known AI
// service signature. Each header is folded to a lower-cased "key: value"
// line and scanned for the configured fragments. Returns the matched
// signature fragment, or the empty string.
func (db *AISignatureDB) MatchHTTPHeaders(headers map[string]string) string {
	db.mu.RLock()
	defer db.mu.RUnlock()
	for name, val := range headers {
		line := strings.ToLower(name + ": " + val)
		for _, sig := range db.httpSignatures {
			if strings.Contains(line, sig) {
				return sig
			}
		}
	}
	return ""
}
// ScanForAPIKeys scans content for AI API keys. Patterns are tried in
// order; the first hit wins. Returns the matched pattern name, or the
// empty string when the content contains no recognized key.
func (db *AISignatureDB) ScanForAPIKeys(content string) string {
	db.mu.RLock()
	defer db.mu.RUnlock()
	for _, kp := range db.apiKeyPatterns {
		if kp.Pattern.MatchString(content) {
			return kp.Name
		}
	}
	return ""
}
// ServiceCount returns the number of known AI services.
func (db *AISignatureDB) ServiceCount() int {
	db.mu.RLock()
	n := len(db.services)
	db.mu.RUnlock()
	return n
}
// DomainPatternCount returns the number of compiled domain patterns.
func (db *AISignatureDB) DomainPatternCount() int {
	db.mu.RLock()
	n := len(db.domainPatterns)
	db.mu.RUnlock()
	return n
}
// AddService registers a custom AI service, compiling one domain pattern
// per configured domain so the service is immediately matchable.
func (db *AISignatureDB) AddService(svc AIServiceInfo) {
	db.mu.Lock()
	defer db.mu.Unlock()

	db.services = append(db.services, svc)
	for _, dom := range svc.Domains {
		db.domainPatterns = append(db.domainPatterns, &domainPattern{
			original: dom,
			regex:    domainToRegex(dom),
			service:  svc.Name,
		})
	}
}
// --- Network Detector ---

// NetworkEvent represents a network connection event for analysis.
type NetworkEvent struct {
	User        string            `json:"user"`
	Hostname    string            `json:"hostname"`
	Destination string            `json:"destination"` // Domain or IP
	Port        int               `json:"port"`
	HTTPHeaders map[string]string `json:"http_headers,omitempty"` // optional; only used for signature matching
	TLSJA3      string            `json:"tls_ja3,omitempty"`      // TLS fingerprint; not consumed by Analyze in this file
	DataSize    int64             `json:"data_size"`              // bytes transferred; copied into detections
	Timestamp   time.Time         `json:"timestamp"`
}

// NetworkDetector analyzes network events for AI service access.
// It is a thin stateless wrapper over an AISignatureDB.
type NetworkDetector struct {
	signatures *AISignatureDB
	logger     *slog.Logger
}
// NewNetworkDetector creates a new network detector with the default signature DB.
func NewNetworkDetector() *NetworkDetector {
	return NewNetworkDetectorWithDB(NewAISignatureDB())
}
// NewNetworkDetectorWithDB creates a detector with a custom signature database.
func NewNetworkDetectorWithDB(db *AISignatureDB) *NetworkDetector {
	nd := &NetworkDetector{signatures: db}
	nd.logger = slog.Default().With("component", "shadow-ai-network")
	return nd
}
// Analyze checks a network event for AI service access.
// Returns a ShadowAIEvent if detected, nil otherwise.
//
// Domain matching is tried first because it identifies the concrete
// service; HTTP header signatures are the fallback, in which case the
// service is reported as "unknown" with the signature in Metadata.
func (nd *NetworkDetector) Analyze(event NetworkEvent) *ShadowAIEvent {
	if service := nd.signatures.MatchDomain(event.Destination); service != "" {
		nd.logger.Info("AI domain detected",
			"user", event.User,
			"destination", event.Destination,
			"service", service,
		)
		detected := ShadowAIEvent{
			UserID:          event.User,
			Hostname:        event.Hostname,
			Destination:     event.Destination,
			AIService:       service,
			DetectionMethod: DetectNetwork,
			Action:          "detected",
			DataSize:        event.DataSize,
			Timestamp:       event.Timestamp,
		}
		return &detected
	}

	sig := nd.signatures.MatchHTTPHeaders(event.HTTPHeaders)
	if sig == "" {
		return nil
	}
	nd.logger.Info("AI HTTP signature detected",
		"user", event.User,
		"destination", event.Destination,
		"signature", sig,
	)
	detected := ShadowAIEvent{
		UserID:          event.User,
		Hostname:        event.Hostname,
		Destination:     event.Destination,
		AIService:       "unknown",
		DetectionMethod: DetectHTTP,
		Action:          "detected",
		DataSize:        event.DataSize,
		Timestamp:       event.Timestamp,
		Metadata:        map[string]string{"http_signature": sig},
	}
	return &detected
}
// SignatureDB returns the underlying signature database for extension
// (e.g. registering custom services via AddService).
func (nd *NetworkDetector) SignatureDB() *AISignatureDB { return nd.signatures }
// --- Behavioral Detector ---

// UserBehaviorProfile tracks a user's AI access behavior for anomaly detection.
// The same struct is used both for learned baselines and for the current
// observation window accumulated by RecordAccess.
type UserBehaviorProfile struct {
	UserID            string    `json:"user_id"`
	AccessFrequency   float64   `json:"access_frequency"`     // Requests per hour
	DataVolumePerHour float64   `json:"data_volume_per_hour"` // Bytes per hour
	KnownDestinations []string  `json:"known_destinations"`
	UpdatedAt         time.Time `json:"updated_at"`
}

// BehavioralAlert is emitted when anomalous AI access is detected.
type BehavioralAlert struct {
	UserID      string  `json:"user_id"`
	AnomalyType string  `json:"anomaly_type"` // "access_spike", "new_destination", "data_volume_spike"
	Current     float64 `json:"current"`      // observed value this window
	Baseline    float64 `json:"baseline"`     // learned reference value
	ZScore      float64 `json:"z_score"`      // deviation estimate; see DetectAnomalies
	Destination string  `json:"destination,omitempty"`
	Severity    string  `json:"severity"`
}

// BehavioralDetector detects anomalous AI usage patterns per user.
// current accumulates the active observation window; baselines hold the
// per-user reference profiles set via SetBaseline.
type BehavioralDetector struct {
	mu        sync.RWMutex
	baselines map[string]*UserBehaviorProfile
	current   map[string]*UserBehaviorProfile
	alertBus  chan BehavioralAlert // buffered; emitAlert drops when full
	logger    *slog.Logger
}
// NewBehavioralDetector creates a behavioral detector with a buffered
// alert bus. A non-positive alertBufSize falls back to a capacity of 100.
func NewBehavioralDetector(alertBufSize int) *BehavioralDetector {
	size := alertBufSize
	if size <= 0 {
		size = 100
	}
	return &BehavioralDetector{
		baselines: make(map[string]*UserBehaviorProfile),
		current:   make(map[string]*UserBehaviorProfile),
		alertBus:  make(chan BehavioralAlert, size),
		logger:    slog.Default().With("component", "shadow-ai-behavioral"),
	}
}
// RecordAccess records a single AI access attempt for behavioral tracking:
// it bumps the user's frequency counter, accumulates the transferred byte
// count, refreshes the update time, and records the destination if it has
// not been seen this window.
func (bd *BehavioralDetector) RecordAccess(userID, destination string, dataSize int64) {
	bd.mu.Lock()
	defer bd.mu.Unlock()

	prof := bd.current[userID]
	if prof == nil {
		prof = &UserBehaviorProfile{UserID: userID}
		bd.current[userID] = prof
	}
	prof.AccessFrequency++
	prof.DataVolumePerHour += float64(dataSize)
	prof.UpdatedAt = time.Now()

	for _, known := range prof.KnownDestinations {
		if known == destination {
			return
		}
	}
	prof.KnownDestinations = append(prof.KnownDestinations, destination)
}
// SetBaseline sets the known baseline behavior for a user, replacing any
// previously stored baseline.
func (bd *BehavioralDetector) SetBaseline(userID string, profile *UserBehaviorProfile) {
	bd.mu.Lock()
	bd.baselines[userID] = profile
	bd.mu.Unlock()
}
// DetectAnomalies compares current behavior to baselines and emits alerts.
//
// Users without a baseline are reported as "first_ai_access" on any
// activity. Users with a baseline are checked for three signals, in
// order: access-frequency spikes, previously unseen destinations, and
// data-volume spikes. Every returned alert is also pushed onto the alert
// bus (non-blocking; dropped with a warning when the bus is full).
func (bd *BehavioralDetector) DetectAnomalies() []BehavioralAlert {
	bd.mu.RLock()
	defer bd.mu.RUnlock()
	var alerts []BehavioralAlert
	for userID, current := range bd.current {
		baseline, ok := bd.baselines[userID]
		if !ok {
			// No baseline — any AI access from this user is suspicious.
			if current.AccessFrequency > 0 {
				alerts = bd.record(alerts, BehavioralAlert{
					UserID:      userID,
					AnomalyType: "first_ai_access",
					Current:     current.AccessFrequency,
					Baseline:    0,
					Severity:    "WARNING",
				})
			}
			continue
		}
		alerts = bd.checkFrequency(alerts, userID, current, baseline)
		alerts = bd.checkDestinations(alerts, userID, current, baseline)
		alerts = bd.checkVolume(alerts, userID, current, baseline)
	}
	return alerts
}

// record appends an alert to the result slice and emits it on the bus.
func (bd *BehavioralDetector) record(alerts []BehavioralAlert, a BehavioralAlert) []BehavioralAlert {
	alerts = append(alerts, a)
	bd.emitAlert(a)
	return alerts
}

// checkFrequency flags a >3-sigma deviation in access frequency. The
// deviation estimate is 30% of the baseline, floored at 1 (a pragmatic
// stand-in for a true standard deviation).
func (bd *BehavioralDetector) checkFrequency(alerts []BehavioralAlert, userID string, current, baseline *UserBehaviorProfile) []BehavioralAlert {
	if baseline.AccessFrequency <= 0 {
		return alerts
	}
	zscore := (current.AccessFrequency - baseline.AccessFrequency) / math.Max(baseline.AccessFrequency*0.3, 1)
	if math.Abs(zscore) <= 3.0 {
		return alerts
	}
	return bd.record(alerts, BehavioralAlert{
		UserID:      userID,
		AnomalyType: "access_spike",
		Current:     current.AccessFrequency,
		Baseline:    baseline.AccessFrequency,
		ZScore:      zscore,
		Severity:    "WARNING",
	})
}

// checkDestinations flags destinations seen this window that are absent
// from the user's baseline destination list.
func (bd *BehavioralDetector) checkDestinations(alerts []BehavioralAlert, userID string, current, baseline *UserBehaviorProfile) []BehavioralAlert {
	for _, dest := range current.KnownDestinations {
		isNew := true
		for _, known := range baseline.KnownDestinations {
			if dest == known {
				isNew = false
				break
			}
		}
		if isNew {
			alerts = bd.record(alerts, BehavioralAlert{
				UserID:      userID,
				AnomalyType: "new_ai_destination",
				Destination: dest,
				Severity:    "HIGH",
			})
		}
	}
	return alerts
}

// checkVolume flags a >3-sigma deviation in hourly data volume, using the
// same 30%-of-baseline deviation estimate as checkFrequency.
func (bd *BehavioralDetector) checkVolume(alerts []BehavioralAlert, userID string, current, baseline *UserBehaviorProfile) []BehavioralAlert {
	if baseline.DataVolumePerHour <= 0 {
		return alerts
	}
	zscore := (current.DataVolumePerHour - baseline.DataVolumePerHour) / math.Max(baseline.DataVolumePerHour*0.3, 1)
	if math.Abs(zscore) <= 3.0 {
		return alerts
	}
	return bd.record(alerts, BehavioralAlert{
		UserID:      userID,
		AnomalyType: "data_volume_spike",
		Current:     current.DataVolumePerHour,
		Baseline:    baseline.DataVolumePerHour,
		ZScore:      zscore,
		Severity:    "CRITICAL",
	})
}
// Alerts returns the receive side of the alert bus for consuming
// behavioral alerts.
func (bd *BehavioralDetector) Alerts() <-chan BehavioralAlert { return bd.alertBus }
// ResetCurrent clears the current period's accumulated profiles; call it
// after each analysis window so the next window starts fresh. Baselines
// are untouched.
func (bd *BehavioralDetector) ResetCurrent() {
	bd.mu.Lock()
	defer bd.mu.Unlock()
	bd.current = map[string]*UserBehaviorProfile{}
}
// emitAlert pushes an alert onto the bus without blocking; when the buffer
// is full the alert is dropped and a warning is logged.
func (bd *BehavioralDetector) emitAlert(alert BehavioralAlert) {
	select {
	case bd.alertBus <- alert:
		return
	default:
	}
	bd.logger.Warn("behavioral alert bus full, dropping alert",
		"user", alert.UserID,
		"type", alert.AnomalyType,
	)
}
// --- Default Data ---

// defaultAIServices returns the built-in catalog of known AI services.
// Domains may contain a single "*" wildcard, expanded by domainToRegex to
// match exactly one DNS label. Categories group services for reporting
// (see ServicesByCategory): llm, code_assist, image_gen, video_gen,
// audio_gen, productivity.
func defaultAIServices() []AIServiceInfo {
	return []AIServiceInfo{
		{Name: "ChatGPT", Vendor: "OpenAI", Domains: []string{"chat.openai.com", "api.openai.com", "*.openai.com"}, Category: "llm"},
		{Name: "Claude", Vendor: "Anthropic", Domains: []string{"claude.ai", "api.anthropic.com", "*.anthropic.com"}, Category: "llm"},
		{Name: "Gemini", Vendor: "Google", Domains: []string{"gemini.google.com", "generativelanguage.googleapis.com", "aistudio.google.com"}, Category: "llm"},
		{Name: "Copilot", Vendor: "Microsoft", Domains: []string{"copilot.microsoft.com", "*.copilot.microsoft.com"}, Category: "code_assist"},
		{Name: "Cohere", Vendor: "Cohere", Domains: []string{"api.cohere.ai", "dashboard.cohere.com", "*.cohere.ai"}, Category: "llm"},
		{Name: "AI21", Vendor: "AI21 Labs", Domains: []string{"api.ai21.com", "studio.ai21.com", "*.ai21.com"}, Category: "llm"},
		{Name: "HuggingFace", Vendor: "Hugging Face", Domains: []string{"api-inference.huggingface.co", "huggingface.co", "*.huggingface.co"}, Category: "llm"},
		{Name: "Replicate", Vendor: "Replicate", Domains: []string{"api.replicate.com", "replicate.com", "*.replicate.com"}, Category: "llm"},
		{Name: "Mistral", Vendor: "Mistral AI", Domains: []string{"api.mistral.ai", "chat.mistral.ai", "*.mistral.ai"}, Category: "llm"},
		{Name: "Perplexity", Vendor: "Perplexity", Domains: []string{"api.perplexity.ai", "perplexity.ai", "*.perplexity.ai"}, Category: "llm"},
		{Name: "Groq", Vendor: "Groq", Domains: []string{"api.groq.com", "groq.com", "*.groq.com"}, Category: "llm"},
		{Name: "Together", Vendor: "Together AI", Domains: []string{"api.together.xyz", "together.ai", "*.together.ai"}, Category: "llm"},
		// Image generation services.
		{Name: "Stability", Vendor: "Stability AI", Domains: []string{"api.stability.ai", "*.stability.ai"}, Category: "image_gen"},
		{Name: "Midjourney", Vendor: "Midjourney", Domains: []string{"midjourney.com", "*.midjourney.com"}, Category: "image_gen"},
		{Name: "DALL-E", Vendor: "OpenAI", Domains: []string{"labs.openai.com"}, Category: "image_gen"},
		// Code assistants.
		{Name: "Cursor", Vendor: "Cursor", Domains: []string{"api2.cursor.sh", "*.cursor.sh"}, Category: "code_assist"},
		{Name: "Replit AI", Vendor: "Replit", Domains: []string{"replit.com", "*.replit.com"}, Category: "code_assist"},
		{Name: "Codeium", Vendor: "Codeium", Domains: []string{"*.codeium.com", "codeium.com"}, Category: "code_assist"},
		{Name: "Tabnine", Vendor: "Tabnine", Domains: []string{"*.tabnine.com", "tabnine.com"}, Category: "code_assist"},
		{Name: "Qwen", Vendor: "Alibaba", Domains: []string{"dashscope.aliyuncs.com", "*.dashscope.aliyuncs.com"}, Category: "llm"},
		{Name: "DeepSeek", Vendor: "DeepSeek", Domains: []string{"api.deepseek.com", "chat.deepseek.com", "*.deepseek.com"}, Category: "llm"},
		{Name: "Kimi", Vendor: "Moonshot AI", Domains: []string{"api.moonshot.cn", "kimi.moonshot.cn", "*.moonshot.cn"}, Category: "llm"},
		{Name: "Baidu ERNIE", Vendor: "Baidu", Domains: []string{"aip.baidubce.com", "erniebot.baidu.com"}, Category: "llm"},
		{Name: "Jasper", Vendor: "Jasper", Domains: []string{"app.jasper.ai", "api.jasper.ai", "*.jasper.ai"}, Category: "llm"},
		{Name: "Writer", Vendor: "Writer", Domains: []string{"writer.com", "api.writer.com", "*.writer.com"}, Category: "llm"},
		{Name: "Notion AI", Vendor: "Notion", Domains: []string{"www.notion.so"}, Category: "productivity"},
		{Name: "Grammarly AI", Vendor: "Grammarly", Domains: []string{"*.grammarly.com"}, Category: "productivity"},
		// Video/audio generation.
		{Name: "Runway", Vendor: "Runway", Domains: []string{"app.runwayml.com", "api.runwayml.com", "*.runwayml.com"}, Category: "video_gen"},
		{Name: "Pika", Vendor: "Pika", Domains: []string{"pika.art", "*.pika.art"}, Category: "video_gen"},
		{Name: "ElevenLabs", Vendor: "ElevenLabs", Domains: []string{"api.elevenlabs.io", "elevenlabs.io", "*.elevenlabs.io"}, Category: "audio_gen"},
		{Name: "Suno", Vendor: "Suno", Domains: []string{"suno.com", "*.suno.com"}, Category: "audio_gen"},
		{Name: "OpenRouter", Vendor: "OpenRouter", Domains: []string{"openrouter.ai", "*.openrouter.ai"}, Category: "llm"},
		{Name: "Scale AI", Vendor: "Scale", Domains: []string{"scale.com", "api.scale.com", "*.scale.com"}, Category: "llm"},
		{Name: "Inflection Pi", Vendor: "Inflection", Domains: []string{"pi.ai", "api.inflection.ai"}, Category: "llm"},
		{Name: "Grok", Vendor: "xAI", Domains: []string{"grok.x.ai", "api.x.ai"}, Category: "llm"},
		{Name: "Character.AI", Vendor: "Character.AI", Domains: []string{"character.ai", "*.character.ai"}, Category: "llm"},
		{Name: "Poe", Vendor: "Quora", Domains: []string{"poe.com", "*.poe.com"}, Category: "llm"},
		{Name: "You.com", Vendor: "You.com", Domains: []string{"you.com", "api.you.com"}, Category: "llm"},
		{Name: "Phind", Vendor: "Phind", Domains: []string{"phind.com", "*.phind.com"}, Category: "llm"},
	}
}
// defaultAPIKeyPatterns returns the built-in regexes for provider API key
// formats. Order matters: ScanForAPIKeys reports the first match.
// The Entropy field is a hint only; it is not checked by ScanForAPIKeys.
func defaultAPIKeyPatterns() []*APIKeyPattern {
	return []*APIKeyPattern{
		{Name: "OpenAI API Key", Pattern: regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}T3BlbkFJ[a-zA-Z0-9]{20,}`), Entropy: 4.5},
		{Name: "OpenAI Project Key", Pattern: regexp.MustCompile(`sk-proj-[a-zA-Z0-9\-_]{48,}`), Entropy: 4.5},
		{Name: "Anthropic API Key", Pattern: regexp.MustCompile(`sk-ant-[a-zA-Z0-9\-_]{90,}`), Entropy: 4.5},
		{Name: "Google AI API Key", Pattern: regexp.MustCompile(`AIza[0-9A-Za-z\-_]{35}`), Entropy: 4.0},
		{Name: "HuggingFace Token", Pattern: regexp.MustCompile(`hf_[a-zA-Z0-9]{34}`), Entropy: 4.5},
		{Name: "Groq API Key", Pattern: regexp.MustCompile(`gsk_[a-zA-Z0-9]{52}`), Entropy: 4.5},
		// NOTE(review): this Cohere pattern matches any UUID-like dashed token
		// (no fixed prefix), so it is prone to false positives — consider
		// tightening or ordering it last. It is already last-but-one here.
		{Name: "Cohere API Key", Pattern: regexp.MustCompile(`[a-zA-Z0-9]{10,}-[a-zA-Z0-9]{4,}-[a-zA-Z0-9]{4,}-[a-zA-Z0-9]{4,}-[a-zA-Z0-9]{12,}`), Entropy: 4.5},
		{Name: "Replicate API Token", Pattern: regexp.MustCompile(`r8_[a-zA-Z0-9]{37}`), Entropy: 4.5},
	}
}
// ServicesByCategory returns the default AI services grouped by category,
// with each group sorted by service name for deterministic output.
func ServicesByCategory() map[string][]AIServiceInfo {
	grouped := make(map[string][]AIServiceInfo)
	for _, svc := range defaultAIServices() {
		grouped[svc.Category] = append(grouped[svc.Category], svc)
	}
	for _, svcs := range grouped {
		sort.Slice(svcs, func(i, j int) bool {
			return svcs[i].Name < svcs[j].Name
		})
	}
	return grouped
}

View file

@ -0,0 +1,353 @@
package shadow_ai
import (
	"crypto/sha256"
	"fmt"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"
)
// --- Document Review Bridge ---
// Controlled gateway for AI access: scans documents for secrets and PII,
// supports content redaction, and routes through the approval workflow.

// DocReviewStatus tracks the lifecycle of a document review.
type DocReviewStatus string

const (
	DocReviewPending  DocReviewStatus = "pending"  // queued, not yet scanned
	DocReviewScanning DocReviewStatus = "scanning" // scan in progress
	DocReviewClean    DocReviewStatus = "clean"    // no PII or secrets found
	DocReviewRedacted DocReviewStatus = "redacted" // PII found; content needs redaction
	DocReviewBlocked  DocReviewStatus = "blocked"  // secrets found or document oversized
	DocReviewApproved DocReviewStatus = "approved" // approved via workflow — not set in this file; TODO confirm where
)
// ScanResult contains the results of scanning a document.
type ScanResult struct {
	DocumentID   string             `json:"document_id"`
	Status       DocReviewStatus    `json:"status"`
	PIIFound     []PIIMatch         `json:"pii_found,omitempty"`
	SecretsFound []SecretMatch      `json:"secrets_found,omitempty"`
	DataClass    DataClassification `json:"data_classification"`
	ContentHash  string             `json:"content_hash"` // hex SHA-256 of the raw content, for dedup
	ScannedAt    time.Time          `json:"scanned_at"`
	SizeBytes    int                `json:"size_bytes"`
}

// PIIMatch represents a detected PII pattern in content.
type PIIMatch struct {
	Type     string `json:"type"`     // "email", "phone", "ssn", "credit_card", "passport"
	Location int    `json:"location"` // Character offset
	Length   int    `json:"length"`
	Masked   string `json:"masked"` // Redacted value, e.g., "j***@example.com"
}

// SecretMatch represents a detected secret/API key in content.
// Location/Length are zero for matches produced by the AI key scan, which
// reports only the pattern name (see ScanDocument).
type SecretMatch struct {
	Type     string `json:"type"` // "api_key", "password", "token", "private_key"
	Location int    `json:"location"`
	Length   int    `json:"length"`
	Provider string `json:"provider"` // "OpenAI", "AWS", "GitHub", etc.
}
// DocBridge manages document scanning, redaction, and review workflow.
// Compiled PII/secret patterns and the AI signature DB are built once and
// reused across scans; results are cached in-memory keyed by document ID.
type DocBridge struct {
	mu          sync.RWMutex
	reviews     map[string]*ScanResult // scan results keyed by document ID
	piiPatterns []*piiPattern
	secretPats  []secretPattern // Cached compiled patterns
	signatures  *AISignatureDB  // Reused across scans
	maxDocSize  int             // bytes
}

// piiPattern couples a PII regex with the masking function used both for
// reporting (PIIMatch.Masked) and for redaction.
type piiPattern struct {
	name   string // PII type name, e.g. "email"
	regex  *regexp.Regexp
	maskFn func(string) string // produces the redacted replacement for a match
}
// NewDocBridge creates a new Document Review Bridge with the default PII
// patterns, secret patterns, AI signature database, and a 10 MB size cap.
func NewDocBridge() *DocBridge {
	const maxDoc = 10 * 1024 * 1024 // 10 MB
	return &DocBridge{
		reviews:     make(map[string]*ScanResult),
		piiPatterns: defaultPIIPatterns(),
		secretPats:  secretPatterns(),
		signatures:  NewAISignatureDB(),
		maxDocSize:  maxDoc,
	}
}
// ScanDocument scans content for PII and secrets, classifies data, stores
// the result keyed by docID, and returns it. Oversized documents are
// blocked without scanning. Any secret blocks the document; PII alone
// marks it for redaction; otherwise it is clean.
//
// NOTE(review): userID is accepted for interface stability but is not
// currently used by the scan itself.
func (db *DocBridge) ScanDocument(docID, content, userID string) *ScanResult {
	digest := sha256.Sum256([]byte(content))
	result := &ScanResult{
		DocumentID:  docID,
		Status:      DocReviewScanning,
		ScannedAt:   time.Now(),
		SizeBytes:   len(content),
		ContentHash: fmt.Sprintf("%x", digest[:]),
	}

	// Oversized documents are rejected outright as critical.
	if result.SizeBytes > db.maxDocSize {
		result.Status = DocReviewBlocked
		result.DataClass = DataCritical
		db.store(result)
		return result
	}

	result.PIIFound = db.scanPII(content)

	// AI API keys (cached signature DB) plus generic secret patterns.
	if keyType := db.signatures.ScanForAPIKeys(content); keyType != "" {
		result.SecretsFound = append(result.SecretsFound, SecretMatch{
			Type:     "api_key",
			Provider: keyType,
		})
	}
	result.SecretsFound = append(result.SecretsFound, db.scanSecrets(content)...)

	result.DataClass = db.classifyData(result)

	switch {
	case len(result.SecretsFound) > 0:
		result.Status = DocReviewBlocked
	case len(result.PIIFound) > 0:
		result.Status = DocReviewRedacted
	default:
		result.Status = DocReviewClean
	}
	db.store(result)
	return result
}
// RedactContent replaces PII and secrets in content with masked values:
// PII matches are rewritten via their per-type mask functions, then the
// generic secret patterns are replaced with fixed redaction markers.
func (db *DocBridge) RedactContent(content string) string {
	redacted := content
	for _, pp := range db.piiPatterns {
		redacted = pp.regex.ReplaceAllStringFunc(redacted, pp.maskFn)
	}
	for _, sp := range db.secretPats {
		redacted = sp.regex.ReplaceAllString(redacted, sp.replacement)
	}
	return redacted
}
// GetReview returns a copy of the scan result for docID, and whether one
// exists. Returning a copy keeps the cached entry immutable to callers.
func (db *DocBridge) GetReview(docID string) (*ScanResult, bool) {
	db.mu.RLock()
	defer db.mu.RUnlock()
	stored, ok := db.reviews[docID]
	if !ok {
		return nil, false
	}
	result := *stored
	return &result, true
}
// RecentReviews returns up to limit reviews, newest first. A limit of
// zero or less yields an empty result (the original code panicked on a
// negative limit via results[:limit]).
func (db *DocBridge) RecentReviews(limit int) []ScanResult {
	db.mu.RLock()
	defer db.mu.RUnlock()
	results := make([]ScanResult, 0, len(db.reviews))
	for _, r := range db.reviews {
		results = append(results, *r)
	}
	// Newest first.
	sort.Slice(results, func(i, j int) bool {
		return results[i].ScannedAt.After(results[j].ScannedAt)
	})
	if limit < 0 {
		limit = 0
	}
	if len(results) > limit {
		results = results[:limit]
	}
	return results
}
// Stats returns aggregate document review statistics: a total plus one
// counter per terminal status. All keys are always present, even at zero.
func (db *DocBridge) Stats() map[string]int {
	db.mu.RLock()
	defer db.mu.RUnlock()

	keys := map[DocReviewStatus]string{
		DocReviewClean:    "clean",
		DocReviewRedacted: "redacted",
		DocReviewBlocked:  "blocked",
	}
	stats := map[string]int{
		"total":    len(db.reviews),
		"clean":    0,
		"redacted": 0,
		"blocked":  0,
	}
	for _, r := range db.reviews {
		if key, known := keys[r.Status]; known {
			stats[key]++
		}
	}
	return stats
}
// store caches a scan result under its document ID, overwriting any
// earlier scan of the same document.
func (db *DocBridge) store(result *ScanResult) {
	db.mu.Lock()
	db.reviews[result.DocumentID] = result
	db.mu.Unlock()
}
// scanPII runs every PII pattern against content and returns one match
// per hit, carrying the offset, length, and a pre-masked display value.
func (db *DocBridge) scanPII(content string) []PIIMatch {
	var found []PIIMatch
	for _, pp := range db.piiPatterns {
		for _, span := range pp.regex.FindAllStringIndex(content, -1) {
			start, end := span[0], span[1]
			found = append(found, PIIMatch{
				Type:     pp.name,
				Location: start,
				Length:   end - start,
				Masked:   pp.maskFn(content[start:end]),
			})
		}
	}
	return found
}
// scanSecrets scans for common secret patterns beyond AI API keys (cloud
// credentials, tokens, private keys, connection strings).
func (db *DocBridge) scanSecrets(content string) []SecretMatch {
	var found []SecretMatch
	for _, sp := range db.secretPats {
		for _, span := range sp.regex.FindAllStringIndex(content, -1) {
			found = append(found, SecretMatch{
				Type:     sp.secretType,
				Location: span[0],
				Length:   span[1] - span[0],
				Provider: sp.provider,
			})
		}
	}
	return found
}
// classifyData determines the data classification from scan results:
// any secret, SSN, credit card, or passport → critical; email/phone →
// confidential; otherwise large documents (>1MB) → internal, else public.
func (db *DocBridge) classifyData(result *ScanResult) DataClassification {
	if len(result.SecretsFound) > 0 {
		return DataCritical
	}
	contactPII := false
	for _, pii := range result.PIIFound {
		switch pii.Type {
		case "ssn", "credit_card", "passport":
			return DataCritical
		case "email", "phone":
			contactPII = true
		}
	}
	switch {
	case contactPII:
		return DataConfidential
	case result.SizeBytes > 1024*1024: // >1MB
		return DataInternal
	default:
		return DataPublic
	}
}
// --- PII Patterns ---

// defaultPIIPatterns returns the built-in PII regexes with their masking
// functions. Each maskFn receives the full matched substring and returns
// the redacted form used for reporting and RedactContent.
func defaultPIIPatterns() []*piiPattern {
	return []*piiPattern{
		{
			name:  "email",
			regex: regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`),
			// Keep the first character of the local part and the full domain.
			maskFn: func(s string) string {
				parts := strings.SplitN(s, "@", 2)
				if len(parts) != 2 {
					return "***@***"
				}
				if len(parts[0]) <= 1 {
					return "*@" + parts[1]
				}
				return string(parts[0][0]) + "***@" + parts[1]
			},
		},
		{
			name: "phone",
			// NOTE(review): this pattern is broad (optional country code,
			// separators) and may match non-phone digit runs — tune as needed.
			regex: regexp.MustCompile(`\+?[1-9]\d{0,2}[\s\-]?\(?\d{3}\)?[\s\-]?\d{3}[\s\-]?\d{2,4}`),
			// Keep the first and last two characters.
			maskFn: func(s string) string {
				if len(s) < 4 {
					return "***"
				}
				return s[:2] + strings.Repeat("*", len(s)-4) + s[len(s)-2:]
			},
		},
		{
			name:  "ssn",
			regex: regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`),
			// SSNs are fully masked; no digits survive.
			maskFn: func(_ string) string {
				return "***-**-****"
			},
		},
		{
			name:  "credit_card",
			regex: regexp.MustCompile(`\b(?:\d{4}[\s\-]?){3}\d{4}\b`),
			// Strip separators, keep only the last four digits.
			maskFn: func(s string) string {
				clean := strings.ReplaceAll(strings.ReplaceAll(s, "-", ""), " ", "")
				if len(clean) < 4 {
					return "****"
				}
				return strings.Repeat("*", len(clean)-4) + clean[len(clean)-4:]
			},
		},
		{
			name: "passport",
			// NOTE(review): 1-2 letters + 6-9 digits also matches many invoice
			// or order numbers; expect false positives.
			regex: regexp.MustCompile(`\b[A-Z]{1,2}\d{6,9}\b`),
			// Keep the leading letters, mask the digits.
			maskFn: func(s string) string {
				if len(s) <= 2 {
					return "**"
				}
				return s[:2] + strings.Repeat("*", len(s)-2)
			},
		},
	}
}
// secretPattern describes a non-AI secret format: its type/provider labels
// for reporting and the fixed replacement marker used by RedactContent.
type secretPattern struct {
	secretType  string
	provider    string
	regex       *regexp.Regexp
	replacement string // literal marker substituted for the whole match
}

// secretPatterns returns the built-in generic secret detectors (cloud
// keys, VCS/chat tokens, private key headers, passwords, DB connection
// strings). Compiled once at DocBridge construction and cached.
func secretPatterns() []secretPattern {
	return []secretPattern{
		{secretType: "aws_key", provider: "AWS", regex: regexp.MustCompile(`AKIA[0-9A-Z]{16}`), replacement: "[AWS_KEY_REDACTED]"},
		{secretType: "github_token", provider: "GitHub", regex: regexp.MustCompile(`ghp_[a-zA-Z0-9]{36}`), replacement: "[GITHUB_TOKEN_REDACTED]"},
		{secretType: "github_token", provider: "GitHub", regex: regexp.MustCompile(`github_pat_[a-zA-Z0-9_]{82}`), replacement: "[GITHUB_PAT_REDACTED]"},
		{secretType: "slack_token", provider: "Slack", regex: regexp.MustCompile(`xoxb-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24}`), replacement: "[SLACK_TOKEN_REDACTED]"},
		{secretType: "private_key", provider: "Generic", regex: regexp.MustCompile(`-----BEGIN (?:RSA |EC |DSA )?PRIVATE KEY-----`), replacement: "[PRIVATE_KEY_REDACTED]"},
		{secretType: "password", provider: "Generic", regex: regexp.MustCompile(`(?i)password\s*[=:]\s*['"]?[^\s'"]{8,}`), replacement: "[PASSWORD_REDACTED]"},
		{secretType: "connection_string", provider: "Database", regex: regexp.MustCompile(`(?i)(?:mysql|postgres|mongodb)://[^\s]+`), replacement: "[DB_CONN_REDACTED]"},
	}
}

View file

@ -0,0 +1,148 @@
package shadow_ai
import (
"context"
"fmt"
"log/slog"
"time"
)
// FallbackManager provides priority-based enforcement with graceful degradation.
// Tries enforcement points in priority order; falls back to detect_only if all are offline.
type FallbackManager struct {
	registry *PluginRegistry
	priority []PluginType // e.g., ["proxy", "firewall", "edr"]; order is the enforcement preference
	strategy string       // "detect_only" | "alert_only"
	logger   *slog.Logger
	// Event logging for detect-only fallback.
	// Set via SetEventLogger; consumed when no enforcement point is healthy.
	eventLogFn func(event ShadowAIEvent)
}
// NewFallbackManager creates a new fallback manager with the default
// enforcement priority (proxy, then firewall, then EDR). An empty
// strategy defaults to "detect_only".
func NewFallbackManager(registry *PluginRegistry, strategy string) *FallbackManager {
	fm := &FallbackManager{
		registry: registry,
		priority: []PluginType{PluginTypeProxy, PluginTypeFirewall, PluginTypeEDR},
		strategy: strategy,
		logger:   slog.Default().With("component", "shadow-ai-fallback"),
	}
	if fm.strategy == "" {
		fm.strategy = "detect_only"
	}
	return fm
}
// SetEventLogger sets the callback used to record detection-only events
// when no enforcement point is available.
func (fm *FallbackManager) SetEventLogger(fn func(ShadowAIEvent)) { fm.eventLogFn = fn }
// BlockDomain attempts to block a domain using the highest-priority
// healthy plugin. For each plugin a NetworkEnforcer implementation is
// preferred; WebGateway (URL-based blocking) is the secondary option.
// Returns the vendor that enforced; when every enforcement point is
// unhealthy or fails, degrades to detect-only logging and returns ("", nil).
func (fm *FallbackManager) BlockDomain(ctx context.Context, domain, reason string) (enforcedBy string, err error) {
	for _, pType := range fm.priority {
		for _, plugin := range fm.registry.GetByType(pType) {
			switch p := plugin.(type) {
			case NetworkEnforcer:
				vendor := p.Vendor()
				if !fm.registry.IsHealthy(vendor) {
					continue
				}
				if blockErr := p.BlockDomain(ctx, domain, reason); blockErr != nil {
					fm.logger.Warn("block failed on enforcer", "vendor", vendor, "error", blockErr)
					continue
				}
				return vendor, nil
			case WebGateway:
				vendor := p.Vendor()
				if !fm.registry.IsHealthy(vendor) {
					continue
				}
				if blockErr := p.BlockURL(ctx, domain, reason); blockErr != nil {
					fm.logger.Warn("block failed on gateway", "vendor", vendor, "error", blockErr)
					continue
				}
				return vendor, nil
			}
		}
	}
	// All enforcement points unavailable — fallback.
	fm.logger.Warn("all enforcement points unavailable, falling to detect_only",
		"domain", domain,
		"strategy", fm.strategy,
	)
	fm.logDetectOnly(domain, reason)
	return "", nil
}
// BlockIP attempts to block an IP using the first healthy firewall enforcer.
// Note: unlike BlockDomain, this iterates all NetworkEnforcers directly rather
// than walking the priority list. Falls back to detect_only when none succeed.
func (fm *FallbackManager) BlockIP(ctx context.Context, ip string, duration time.Duration, reason string) (enforcedBy string, err error) {
	for _, enforcer := range fm.registry.GetNetworkEnforcers() {
		vendor := enforcer.Vendor()
		if !fm.registry.IsHealthy(vendor) {
			continue
		}
		if blockErr := enforcer.BlockIP(ctx, ip, duration, reason); blockErr != nil {
			fm.logger.Warn("block IP failed", "vendor", vendor, "error", blockErr)
			continue
		}
		return vendor, nil
	}
	fm.logger.Warn("no healthy enforcer for IP block, falling to detect_only",
		"ip", ip,
		"strategy", fm.strategy,
	)
	fm.logDetectOnly(ip, reason)
	return "", nil
}
// IsolateHost attempts to isolate a host using the first healthy EDR controller.
// Unlike BlockDomain/BlockIP, there is no detect-only fallback: failure to
// isolate a host is surfaced to the caller as an error.
func (fm *FallbackManager) IsolateHost(ctx context.Context, hostname string) (enforcedBy string, err error) {
	for _, controller := range fm.registry.GetEndpointControllers() {
		vendor := controller.Vendor()
		if !fm.registry.IsHealthy(vendor) {
			continue
		}
		if isoErr := controller.IsolateHost(ctx, hostname); isoErr != nil {
			fm.logger.Warn("isolate failed", "vendor", vendor, "error", isoErr)
			continue
		}
		return vendor, nil
	}
	fm.logger.Warn("no healthy EDR for host isolation, falling to detect_only",
		"hostname", hostname,
		"strategy", fm.strategy,
	)
	return "", fmt.Errorf("no healthy EDR available for host isolation")
}
// logDetectOnly records a detection-only event when no enforcement is possible.
// A nil eventLogFn means detect-only events are silently dropped.
func (fm *FallbackManager) logDetectOnly(target, reason string) {
	if fm.eventLogFn == nil {
		return
	}
	event := ShadowAIEvent{
		Destination:     target,
		DetectionMethod: DetectNetwork,
		Action:          "detect_only",
		Metadata: map[string]string{
			"reason":            reason,
			"fallback_strategy": fm.strategy,
		},
		Timestamp: time.Now(),
	}
	fm.eventLogFn(event)
}

// Strategy returns the configured fallback strategy.
func (fm *FallbackManager) Strategy() string {
	return fm.strategy
}

View file

@ -0,0 +1,163 @@
package shadow_ai
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// PluginStatus represents a plugin's operational state.
type PluginStatus string

const (
	PluginStatusHealthy  PluginStatus = "healthy"
	PluginStatusDegraded PluginStatus = "degraded" // failing, but below the offline threshold
	PluginStatusOffline  PluginStatus = "offline"  // reached MaxConsecutivePluginFailures
)

// PluginHealth tracks the health state of a single plugin.
// Snapshots are written by the HealthChecker and stored in the PluginRegistry.
type PluginHealth struct {
	Vendor      string        `json:"vendor"`
	Type        PluginType    `json:"type"`
	Status      PluginStatus  `json:"status"`
	LastCheck   time.Time     `json:"last_check"`
	Consecutive int           `json:"consecutive_failures"` // consecutive failed checks; reset to 0 on success
	Latency     time.Duration `json:"latency"`              // duration of the most recent health check
	LastError   string        `json:"last_error,omitempty"`
}

// MaxConsecutivePluginFailures before marking offline.
const MaxConsecutivePluginFailures = 3
// HealthChecker performs continuous health monitoring of all registered plugins.
type HealthChecker struct {
	mu       sync.RWMutex // NOTE(review): not taken by any method in this file — confirm it is needed
	registry *PluginRegistry
	interval time.Duration // time between check passes (defaulted to 30s in NewHealthChecker)
	alertFn  func(vendor string, status PluginStatus, msg string) // optional; invoked on offline/recovery transitions
	logger   *slog.Logger
}
// NewHealthChecker creates a health checker that monitors plugin health.
// A non-positive interval falls back to 30 seconds. alertFn may be nil.
func NewHealthChecker(registry *PluginRegistry, interval time.Duration, alertFn func(string, PluginStatus, string)) *HealthChecker {
	const defaultInterval = 30 * time.Second
	if interval <= 0 {
		interval = defaultInterval
	}
	return &HealthChecker{
		registry: registry,
		interval: interval,
		alertFn:  alertFn,
		logger:   slog.Default().With("component", "shadow-ai-health"),
	}
}
// Start begins continuous health monitoring. Blocks until ctx is cancelled.
// One check pass runs per tick; the first pass happens only after the first
// interval elapses — use CheckNow for an immediate pass.
func (hc *HealthChecker) Start(ctx context.Context) {
	hc.logger.Info("health checker started", "interval", hc.interval)
	ticker := time.NewTicker(hc.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			hc.logger.Info("health checker stopped")
			return
		case <-ticker.C:
			hc.checkAllPlugins(ctx)
		}
	}
}
// checkAllPlugins runs health checks on all registered plugins.
// For each vendor it measures check latency, applies the consecutive-failure
// escalation policy (degraded, then offline at MaxConsecutivePluginFailures),
// fires alertFn only on offline/recovery state transitions, and writes the
// fresh health snapshot back to the registry.
func (hc *HealthChecker) checkAllPlugins(ctx context.Context) {
	vendors := hc.registry.Vendors()
	for _, vendor := range vendors {
		plugin, ok := hc.registry.Get(vendor)
		if !ok {
			// Plugin disappeared between Vendors() and Get() — skip.
			continue
		}
		existing, _ := hc.registry.GetHealth(vendor)
		if existing == nil {
			// No baseline record (the registry seeds one at load time);
			// without it consecutive failures cannot be tracked — skip.
			continue
		}
		start := time.Now()
		err := hc.checkPlugin(ctx, plugin)
		latency := time.Since(start)
		health := &PluginHealth{
			Vendor:    vendor,
			Type:      existing.Type,
			LastCheck: time.Now(),
			Latency:   latency,
		}
		if err != nil {
			health.Consecutive = existing.Consecutive + 1
			health.LastError = err.Error()
			if health.Consecutive >= MaxConsecutivePluginFailures {
				health.Status = PluginStatusOffline
				// Alert only on the edge into offline, not on every
				// subsequent failed check.
				if existing.Status != PluginStatusOffline {
					hc.logger.Error("plugin went OFFLINE",
						"vendor", vendor,
						"consecutive", health.Consecutive,
						"error", err,
					)
					if hc.alertFn != nil {
						hc.alertFn(vendor, PluginStatusOffline,
							fmt.Sprintf("Plugin %s offline after %d consecutive failures: %v",
								vendor, health.Consecutive, err))
					}
				}
			} else {
				health.Status = PluginStatusDegraded
				hc.logger.Warn("plugin health check failed",
					"vendor", vendor,
					"consecutive", health.Consecutive,
					"error", err,
				)
			}
		} else {
			health.Status = PluginStatusHealthy
			health.Consecutive = 0
			// Log recovery if previously degraded/offline.
			if existing.Status != PluginStatusHealthy {
				hc.logger.Info("plugin recovered", "vendor", vendor, "latency", latency)
				if hc.alertFn != nil {
					hc.alertFn(vendor, PluginStatusHealthy,
						fmt.Sprintf("Plugin %s recovered, latency %s", vendor, latency))
				}
			}
		}
		hc.registry.SetHealth(vendor, health)
	}
}
// checkPlugin runs the health check for a single plugin, bounded by a
// 10-second timeout derived from ctx.
//
// All plugin interfaces (NetworkEnforcer, EndpointController, WebGateway)
// declare the same HealthCheck(context.Context) error method, so a single
// anonymous-interface assertion replaces the previous type switch whose three
// cases had identical bodies. This also covers any future plugin kind (e.g.
// a DNS plugin) that exposes HealthCheck without implementing one of the
// three named interfaces.
//
// Returns an error when the plugin does not expose HealthCheck, or whatever
// error the plugin's own check reports.
func (hc *HealthChecker) checkPlugin(ctx context.Context, plugin interface{}) error {
	checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	hp, ok := plugin.(interface{ HealthCheck(context.Context) error })
	if !ok {
		return fmt.Errorf("plugin does not implement HealthCheck")
	}
	return hp.HealthCheck(checkCtx)
}
// CheckNow runs an immediate health check pass on all plugins.
// NOTE(review): this call is synchronous — it returns only after every plugin
// has been checked (the former "(non-blocking)" description was inaccurate).
func (hc *HealthChecker) CheckNow(ctx context.Context) {
	hc.checkAllPlugins(ctx)
}

View file

@ -0,0 +1,225 @@
// Package shadow_ai implements the Sentinel Shadow AI Control Module.
//
// Five levels of shadow AI management:
//
// L1 — Universal Integration Layer: plugin-based enforcement (firewall, EDR, proxy)
// L2 — Detection Engine: network signatures, endpoint, API keys, behavioral
// L3 — Document Review Bridge: controlled LLM access with PII/secret scanning
// L4 — Approval Workflow: tiered data classification and manager/SOC approval
// L5 — SOC Integration: dashboard, correlation rules, playbooks, compliance
package shadow_ai
import (
"context"
"time"
)
// --- Plugin Interfaces ---

// NetworkEnforcer is the universal interface for ALL firewalls.
// Implementations: Check Point, Cisco ASA/FMC, Palo Alto, Fortinet.
type NetworkEnforcer interface {
	// BlockIP blocks an IP address for the given duration.
	BlockIP(ctx context.Context, ip string, duration time.Duration, reason string) error
	// BlockDomain blocks a domain name.
	BlockDomain(ctx context.Context, domain string, reason string) error
	// UnblockIP removes an IP block.
	UnblockIP(ctx context.Context, ip string) error
	// UnblockDomain removes a domain block.
	UnblockDomain(ctx context.Context, domain string) error
	// HealthCheck verifies the firewall API is reachable.
	HealthCheck(ctx context.Context) error
	// Vendor returns the vendor identifier (e.g., "checkpoint", "cisco", "paloalto").
	Vendor() string
}

// EndpointController is the universal interface for ALL EDR systems.
// Implementations: CrowdStrike, SentinelOne, Microsoft Defender.
type EndpointController interface {
	// IsolateHost quarantines a host from the network.
	IsolateHost(ctx context.Context, hostname string) error
	// ReleaseHost removes host isolation.
	ReleaseHost(ctx context.Context, hostname string) error
	// KillProcess terminates a process on a remote host.
	KillProcess(ctx context.Context, hostname string, pid int) error
	// QuarantineFile moves a file to quarantine on a remote host.
	QuarantineFile(ctx context.Context, hostname string, path string) error
	// HealthCheck verifies the EDR API is reachable.
	HealthCheck(ctx context.Context) error
	// Vendor returns the vendor identifier (e.g., "crowdstrike", "sentinelone", "defender").
	Vendor() string
}

// WebGateway is the universal interface for ALL proxy/CASB systems.
// Implementations: Zscaler, Netskope, Squid, BlueCoat.
type WebGateway interface {
	// BlockURL adds a URL to the blocklist.
	BlockURL(ctx context.Context, url string, reason string) error
	// UnblockURL removes a URL from the blocklist.
	UnblockURL(ctx context.Context, url string) error
	// BlockCategory blocks an entire URL category (e.g., "Artificial Intelligence").
	BlockCategory(ctx context.Context, category string) error
	// HealthCheck verifies the gateway API is reachable.
	HealthCheck(ctx context.Context) error
	// Vendor returns the vendor identifier (e.g., "zscaler", "netskope", "squid").
	Vendor() string
}

// Initializer is implemented by plugins that need configuration before use.
// The registry calls Initialize with the plugin's YAML-derived config map
// right after construction; a returned error aborts loading of that plugin.
type Initializer interface {
	Initialize(config map[string]interface{}) error
}

// --- Plugin Configuration ---

// PluginType categorizes enforcement points.
type PluginType string

const (
	PluginTypeFirewall PluginType = "firewall"
	PluginTypeEDR      PluginType = "edr"
	PluginTypeProxy    PluginType = "proxy"
	PluginTypeDNS      PluginType = "dns"
)

// PluginConfig defines a vendor plugin configuration loaded from YAML.
type PluginConfig struct {
	Type    PluginType             `yaml:"type" json:"type"`
	Vendor  string                 `yaml:"vendor" json:"vendor"`
	Enabled bool                   `yaml:"enabled" json:"enabled"` // disabled plugins are skipped at load time
	Config  map[string]interface{} `yaml:"config" json:"config"`   // vendor-specific keys, passed to Initializer
}

// IntegrationConfig is the top-level Shadow AI configuration.
type IntegrationConfig struct {
	Plugins             []PluginConfig `yaml:"plugins" json:"plugins"`
	FallbackStrategy    string         `yaml:"fallback_strategy" json:"fallback_strategy"`       // "detect_only" | "alert_only"
	HealthCheckInterval time.Duration  `yaml:"health_check_interval" json:"health_check_interval"` // default: 30s
}

// --- Domain Types ---

// DetectionMethod identifies how a shadow AI usage was detected.
type DetectionMethod string

const (
	DetectNetwork    DetectionMethod = "network"    // Domain/IP match
	DetectHTTP       DetectionMethod = "http"       // HTTP header signature
	DetectTLS        DetectionMethod = "tls"        // TLS/JA3 fingerprint
	DetectProcess    DetectionMethod = "process"    // AI tool process execution
	DetectAPIKey     DetectionMethod = "api_key"    // AI API key in payload
	DetectBehavioral DetectionMethod = "behavioral" // Anomalous AI access pattern
	DetectClipboard  DetectionMethod = "clipboard"  // Large clipboard → AI browser pattern
)

// DataClassification determines the approval tier required.
type DataClassification string

const (
	DataPublic       DataClassification = "PUBLIC"
	DataInternal     DataClassification = "INTERNAL"
	DataConfidential DataClassification = "CONFIDENTIAL"
	DataCritical     DataClassification = "CRITICAL"
)

// ShadowAIEvent is a detected shadow AI usage attempt.
type ShadowAIEvent struct {
	ID              string            `json:"id"`
	UserID          string            `json:"user_id"`
	Hostname        string            `json:"hostname"`
	Destination     string            `json:"destination"` // Target AI service domain/IP
	AIService       string            `json:"ai_service"`  // "chatgpt", "claude", "gemini", etc.
	DetectionMethod DetectionMethod   `json:"detection_method"`
	Action          string            `json:"action"`      // "blocked", "allowed", "pending"
	EnforcedBy      string            `json:"enforced_by"` // Plugin vendor that enforced
	DataSize        int64             `json:"data_size"`   // Bytes sent to AI
	Timestamp       time.Time         `json:"timestamp"`
	Metadata        map[string]string `json:"metadata,omitempty"`
}

// AIServiceInfo describes a known AI service for signature matching.
type AIServiceInfo struct {
	Name     string   `json:"name"`     // "ChatGPT", "Claude", "Gemini"
	Vendor   string   `json:"vendor"`   // "OpenAI", "Anthropic", "Google"
	Domains  []string `json:"domains"`  // ["*.openai.com", "chat.openai.com"]
	Category string   `json:"category"` // "llm", "image_gen", "code_assist"
}

// BlockRequest is an API request to manually block a target.
type BlockRequest struct {
	TargetType string        `json:"target_type"` // "ip", "domain", "user"
	Target     string        `json:"target"`
	Duration   time.Duration `json:"duration"`
	Reason     string        `json:"reason"`
	BlockedBy  string        `json:"blocked_by"` // RBAC user
}

// ShadowAIStats provides aggregate statistics for the dashboard.
type ShadowAIStats struct {
	TimeRange    string         `json:"time_range"` // "24h", "7d", "30d"
	Total        int            `json:"total_attempts"`
	Blocked      int            `json:"blocked"`
	Approved     int            `json:"approved"`
	Pending      int            `json:"pending"`
	ByService    map[string]int `json:"by_service"`
	ByDepartment map[string]int `json:"by_department"`
	TopViolators []Violator     `json:"top_violators"`
}

// Violator tracks a user's shadow AI violation count.
type Violator struct {
	UserID   string `json:"user_id"`
	Attempts int    `json:"attempts"`
}

// ApprovalTier defines the approval requirements for a data classification level.
type ApprovalTier struct {
	Name           string             `yaml:"name" json:"name"`
	DataClass      DataClassification `yaml:"data_class" json:"data_class"`
	ApprovalNeeded []string           `yaml:"approval_needed" json:"approval_needed"` // ["manager"], ["manager", "soc"], ["ciso"]
	SLA            time.Duration      `yaml:"sla" json:"sla"`
	AutoApprove    bool               `yaml:"auto_approve" json:"auto_approve"`
}

// ApprovalRequest tracks a pending approval for AI access.
type ApprovalRequest struct {
	ID         string             `json:"id"`
	DocID      string             `json:"doc_id"`
	UserID     string             `json:"user_id"`
	Tier       string             `json:"tier"`
	DataClass  DataClassification `json:"data_class"`
	Status     string             `json:"status"` // "pending", "approved", "denied", "expired"
	ApprovedBy string             `json:"approved_by,omitempty"`
	DeniedBy   string             `json:"denied_by,omitempty"`
	Reason     string             `json:"reason,omitempty"`
	CreatedAt  time.Time          `json:"created_at"`
	ExpiresAt  time.Time          `json:"expires_at"`
	ResolvedAt time.Time          `json:"resolved_at,omitempty"`
}

// ComplianceReport is the Shadow AI compliance report for GDPR/SOC2/EU AI Act.
type ComplianceReport struct {
	GeneratedAt       time.Time `json:"generated_at"`
	Period            string    `json:"period"` // "monthly", "quarterly"
	TotalInteractions int       `json:"total_interactions"`
	BlockedAttempts   int       `json:"blocked_attempts"`
	ApprovedReviews   int       `json:"approved_reviews"`
	PIIDetected       int       `json:"pii_detected"`
	SecretsDetected   int       `json:"secrets_detected"`
	AuditComplete     bool      `json:"audit_complete"`
	Regulations       []string  `json:"regulations"` // ["GDPR", "SOC2", "EU AI Act"]
}

View file

@ -0,0 +1,212 @@
package shadow_ai
import (
"context"
"fmt"
"log/slog"
"time"
)
// --- Vendor Plugin Stubs ---
// Reference implementations for major security vendors.
// These stubs implement the full interface with logging but no real API calls.
// Production deployments replace these with real vendor SDK integrations.

// CheckPointEnforcer is a stub implementation for Check Point firewalls.
// Satisfies NetworkEnforcer and Initializer.
type CheckPointEnforcer struct {
	apiURL string // Check Point Management API endpoint; required
	apiKey string // API key; optional in this stub
	logger *slog.Logger
}

// NewCheckPointEnforcer creates an unconfigured stub; call Initialize before use.
func NewCheckPointEnforcer() *CheckPointEnforcer {
	return &CheckPointEnforcer{
		logger: slog.Default().With("component", "shadow-ai-plugin-checkpoint"),
	}
}

// Initialize reads "api_url" (required) and "api_key" (optional) from config.
func (c *CheckPointEnforcer) Initialize(config map[string]interface{}) error {
	if url, ok := config["api_url"].(string); ok {
		c.apiURL = url
	}
	if key, ok := config["api_key"].(string); ok {
		c.apiKey = key
	}
	if c.apiURL == "" {
		return fmt.Errorf("checkpoint: api_url required")
	}
	c.logger.Info("initialized", "api_url", c.apiURL)
	return nil
}

// BlockIP logs the request; the duration is recorded but not enforced by this stub.
func (c *CheckPointEnforcer) BlockIP(_ context.Context, ip string, duration time.Duration, reason string) error {
	c.logger.Info("block IP", "ip", ip, "duration", duration, "reason", reason)
	// Stub: would call Check Point Management API POST /web_api/add-host
	return nil
}

// BlockDomain logs the request; no real rule is created by this stub.
func (c *CheckPointEnforcer) BlockDomain(_ context.Context, domain string, reason string) error {
	c.logger.Info("block domain", "domain", domain, "reason", reason)
	// Stub: would create application-site-category block rule
	return nil
}

// UnblockIP logs the request only.
func (c *CheckPointEnforcer) UnblockIP(_ context.Context, ip string) error {
	c.logger.Info("unblock IP", "ip", ip)
	return nil
}

// UnblockDomain logs the request only.
func (c *CheckPointEnforcer) UnblockDomain(_ context.Context, domain string) error {
	c.logger.Info("unblock domain", "domain", domain)
	return nil
}

// HealthCheck reports an error only when Initialize has not set api_url.
func (c *CheckPointEnforcer) HealthCheck(ctx context.Context) error {
	if c.apiURL == "" {
		return fmt.Errorf("not configured")
	}
	// Stub: would call GET /web_api/show-session
	return nil
}

// Vendor returns the registry identifier for this plugin.
func (c *CheckPointEnforcer) Vendor() string { return "checkpoint" }
// CrowdStrikeController is a stub implementation for CrowdStrike Falcon EDR.
// Satisfies EndpointController and Initializer.
type CrowdStrikeController struct {
	clientID     string // OAuth2 client ID; required
	clientSecret string // OAuth2 client secret; optional in this stub
	baseURL      string // defaults to the public Falcon API endpoint
	logger       *slog.Logger
}

// NewCrowdStrikeController creates a stub pointed at the public Falcon API;
// call Initialize before use.
func NewCrowdStrikeController() *CrowdStrikeController {
	return &CrowdStrikeController{
		baseURL: "https://api.crowdstrike.com",
		logger:  slog.Default().With("component", "shadow-ai-plugin-crowdstrike"),
	}
}

// Initialize reads "client_id" (required), "client_secret" and "base_url"
// (both optional) from config.
func (cs *CrowdStrikeController) Initialize(config map[string]interface{}) error {
	if id, ok := config["client_id"].(string); ok {
		cs.clientID = id
	}
	if secret, ok := config["client_secret"].(string); ok {
		cs.clientSecret = secret
	}
	if url, ok := config["base_url"].(string); ok {
		cs.baseURL = url
	}
	if cs.clientID == "" {
		return fmt.Errorf("crowdstrike: client_id required")
	}
	cs.logger.Info("initialized", "base_url", cs.baseURL)
	return nil
}

// IsolateHost logs the request only.
func (cs *CrowdStrikeController) IsolateHost(_ context.Context, hostname string) error {
	cs.logger.Info("isolate host", "hostname", hostname)
	// Stub: would call POST /devices/entities/devices-actions/v2?action_name=contain
	return nil
}

// ReleaseHost logs the request only.
func (cs *CrowdStrikeController) ReleaseHost(_ context.Context, hostname string) error {
	cs.logger.Info("release host", "hostname", hostname)
	// Stub: would call POST /devices/entities/devices-actions/v2?action_name=lift_containment
	return nil
}

// KillProcess logs the request only.
func (cs *CrowdStrikeController) KillProcess(_ context.Context, hostname string, pid int) error {
	cs.logger.Info("kill process", "hostname", hostname, "pid", pid)
	// Stub: would use RTR session to kill process
	return nil
}

// QuarantineFile logs the request only.
func (cs *CrowdStrikeController) QuarantineFile(_ context.Context, hostname, path string) error {
	cs.logger.Info("quarantine file", "hostname", hostname, "path", path)
	return nil
}

// HealthCheck reports an error only when Initialize has not set client_id.
func (cs *CrowdStrikeController) HealthCheck(ctx context.Context) error {
	if cs.clientID == "" {
		return fmt.Errorf("not configured")
	}
	// Stub: would call GET /sensors/queries/sensors/v1?limit=1
	return nil
}

// Vendor returns the registry identifier for this plugin.
func (cs *CrowdStrikeController) Vendor() string { return "crowdstrike" }
// ZscalerGateway is a stub implementation for Zscaler Internet Access.
// Satisfies WebGateway and Initializer.
type ZscalerGateway struct {
	cloudName string // ZIA cloud name; required
	apiKey    string // optional in this stub
	username  string // optional in this stub
	password  string // optional in this stub
	logger    *slog.Logger
}

// NewZscalerGateway creates an unconfigured stub; call Initialize before use.
func NewZscalerGateway() *ZscalerGateway {
	return &ZscalerGateway{
		logger: slog.Default().With("component", "shadow-ai-plugin-zscaler"),
	}
}

// Initialize reads "cloud_name" (required) plus optional "api_key",
// "username", and "password" from config.
func (z *ZscalerGateway) Initialize(config map[string]interface{}) error {
	if cloud, ok := config["cloud_name"].(string); ok {
		z.cloudName = cloud
	}
	if key, ok := config["api_key"].(string); ok {
		z.apiKey = key
	}
	if user, ok := config["username"].(string); ok {
		z.username = user
	}
	if pass, ok := config["password"].(string); ok {
		z.password = pass
	}
	if z.cloudName == "" {
		return fmt.Errorf("zscaler: cloud_name required")
	}
	z.logger.Info("initialized", "cloud", z.cloudName)
	return nil
}

// BlockURL logs the request only.
func (z *ZscalerGateway) BlockURL(_ context.Context, url, reason string) error {
	z.logger.Info("block URL", "url", url, "reason", reason)
	// Stub: would call PUT /webApplicationRules to add URL to block list
	return nil
}

// UnblockURL logs the request only.
func (z *ZscalerGateway) UnblockURL(_ context.Context, url string) error {
	z.logger.Info("unblock URL", "url", url)
	return nil
}

// BlockCategory logs the request only.
func (z *ZscalerGateway) BlockCategory(_ context.Context, category string) error {
	z.logger.Info("block category", "category", category)
	// Stub: would update URL category policy to BLOCK
	return nil
}

// HealthCheck reports an error only when Initialize has not set cloud_name.
func (z *ZscalerGateway) HealthCheck(ctx context.Context) error {
	if z.cloudName == "" {
		return fmt.Errorf("not configured")
	}
	// Stub: would call GET /status
	return nil
}

// Vendor returns the registry identifier for this plugin.
func (z *ZscalerGateway) Vendor() string { return "zscaler" }
// RegisterDefaultPlugins registers all built-in vendor plugin factories:
// Check Point (firewall), CrowdStrike (EDR), and Zscaler (proxy).
func RegisterDefaultPlugins(registry *PluginRegistry) {
	defaults := []struct {
		pluginType PluginType
		vendor     string
		build      PluginFactory
	}{
		{PluginTypeFirewall, "checkpoint", func() interface{} { return NewCheckPointEnforcer() }},
		{PluginTypeEDR, "crowdstrike", func() interface{} { return NewCrowdStrikeController() }},
		{PluginTypeProxy, "zscaler", func() interface{} { return NewZscalerGateway() }},
	}
	for _, d := range defaults {
		registry.RegisterFactory(d.pluginType, d.vendor, d.build)
	}
}

View file

@ -0,0 +1,212 @@
package shadow_ai
import (
"fmt"
"log/slog"
"sync"
)
// PluginFactory creates a new plugin instance.
type PluginFactory func() interface{}

// PluginRegistry manages vendor plugin registration, loading, and lifecycle.
// Thread-safe via sync.RWMutex.
type PluginRegistry struct {
	mu        sync.RWMutex
	plugins   map[string]interface{}   // vendor → plugin instance
	factories map[string]PluginFactory // "type_vendor" → factory (key format set in RegisterFactory)
	configs   map[string]*PluginConfig // vendor → config
	health    map[string]*PluginHealth // vendor → health status
	logger    *slog.Logger
}
// NewPluginRegistry creates a new plugin registry with empty maps.
func NewPluginRegistry() *PluginRegistry {
	r := &PluginRegistry{
		plugins:   map[string]interface{}{},
		factories: map[string]PluginFactory{},
		configs:   map[string]*PluginConfig{},
		health:    map[string]*PluginHealth{},
	}
	r.logger = slog.Default().With("component", "shadow-ai-registry")
	return r
}

// RegisterFactory registers a plugin factory for a given type+vendor combination.
// Example: RegisterFactory("firewall", "checkpoint", func() interface{} { return &CheckPointEnforcer{} })
func (r *PluginRegistry) RegisterFactory(pluginType PluginType, vendor string, factory PluginFactory) {
	key := fmt.Sprintf("%s_%s", pluginType, vendor)
	r.mu.Lock()
	r.factories[key] = factory
	r.mu.Unlock()
	r.logger.Info("factory registered", "type", pluginType, "vendor", vendor)
}
// LoadPlugins creates and initializes plugins from configuration.
// Plugins that fail to initialize are logged but do not block other plugins;
// the returned error is currently always nil.
func (r *PluginRegistry) LoadPlugins(config *IntegrationConfig) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	loaded := 0
	for i := range config.Plugins {
		cfg := &config.Plugins[i]
		if !cfg.Enabled {
			r.logger.Debug("plugin disabled, skipping", "vendor", cfg.Vendor)
			continue
		}
		key := fmt.Sprintf("%s_%s", cfg.Type, cfg.Vendor)
		factory, found := r.factories[key]
		if !found {
			r.logger.Warn("no factory for plugin", "key", key, "vendor", cfg.Vendor)
			continue
		}
		instance := factory()
		// Initialize if plugin supports it; a failed init drops the plugin.
		if init, isInit := instance.(Initializer); isInit {
			if err := init.Initialize(cfg.Config); err != nil {
				r.logger.Error("plugin init failed", "vendor", cfg.Vendor, "error", err)
				continue
			}
		}
		r.plugins[cfg.Vendor] = instance
		r.configs[cfg.Vendor] = cfg
		// Seed a healthy baseline record; the health checker refines it later.
		r.health[cfg.Vendor] = &PluginHealth{
			Vendor: cfg.Vendor,
			Type:   cfg.Type,
			Status: PluginStatusHealthy,
		}
		loaded++
		r.logger.Info("plugin loaded", "vendor", cfg.Vendor, "type", cfg.Type)
	}
	r.logger.Info("plugin loading complete", "loaded", loaded, "total", len(config.Plugins))
	return nil
}
// Get returns a plugin by vendor name.
func (r *PluginRegistry) Get(vendor string) (interface{}, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	plugin, ok := r.plugins[vendor]
	return plugin, ok
}

// GetByType returns all plugins of a given type, as recorded in their configs.
func (r *PluginRegistry) GetByType(pluginType PluginType) []interface{} {
	r.mu.RLock()
	defer r.mu.RUnlock()
	var matches []interface{}
	for vendor, cfg := range r.configs {
		if cfg.Type != pluginType {
			continue
		}
		if plugin, ok := r.plugins[vendor]; ok {
			matches = append(matches, plugin)
		}
	}
	return matches
}
// collectPlugins returns every loaded plugin whose dynamic type satisfies T.
// It centralizes the lock/type-assert/append pattern previously duplicated
// verbatim across the three typed getters below.
func collectPlugins[T any](r *PluginRegistry) []T {
	r.mu.RLock()
	defer r.mu.RUnlock()
	var result []T
	for _, plugin := range r.plugins {
		if typed, ok := plugin.(T); ok {
			result = append(result, typed)
		}
	}
	return result
}

// GetNetworkEnforcers returns all loaded NetworkEnforcer plugins.
func (r *PluginRegistry) GetNetworkEnforcers() []NetworkEnforcer {
	return collectPlugins[NetworkEnforcer](r)
}

// GetEndpointControllers returns all loaded EndpointController plugins.
func (r *PluginRegistry) GetEndpointControllers() []EndpointController {
	return collectPlugins[EndpointController](r)
}

// GetWebGateways returns all loaded WebGateway plugins.
func (r *PluginRegistry) GetWebGateways() []WebGateway {
	return collectPlugins[WebGateway](r)
}
// IsHealthy returns true if a plugin has a health record in the healthy state.
func (r *PluginRegistry) IsHealthy(vendor string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	if h, ok := r.health[vendor]; ok {
		return h.Status == PluginStatusHealthy
	}
	return false
}

// SetHealth updates the health status for a plugin.
func (r *PluginRegistry) SetHealth(vendor string, health *PluginHealth) {
	r.mu.Lock()
	r.health[vendor] = health
	r.mu.Unlock()
}

// GetHealth returns a copy of the health status for a plugin, so callers
// cannot mutate registry state through the returned pointer.
func (r *PluginRegistry) GetHealth(vendor string) (*PluginHealth, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	h, ok := r.health[vendor]
	if !ok {
		return nil, false
	}
	snapshot := *h
	return &snapshot, true
}
// AllHealth returns health snapshots (copies) for all plugins.
func (r *PluginRegistry) AllHealth() []PluginHealth {
	r.mu.RLock()
	defer r.mu.RUnlock()
	snapshots := make([]PluginHealth, 0, len(r.health))
	for _, h := range r.health {
		snapshots = append(snapshots, *h)
	}
	return snapshots
}

// PluginCount returns the number of loaded plugins.
func (r *PluginRegistry) PluginCount() int {
	r.mu.RLock()
	n := len(r.plugins)
	r.mu.RUnlock()
	return n
}

// Vendors returns all loaded vendor names (unordered).
func (r *PluginRegistry) Vendors() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	names := make([]string, 0, len(r.plugins))
	for vendor := range r.plugins {
		names = append(names, vendor)
	}
	return names
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,373 @@
package shadow_ai
import (
	"context"
	"fmt"
	"log/slog"
	"sort"
	"sync"
	"time"
)
// ShadowAIController is the main orchestrator that ties together
// detection, enforcement, SOC event emission, and statistics.
type ShadowAIController struct {
	mu            sync.RWMutex
	registry      *PluginRegistry
	fallback      *FallbackManager
	healthChecker *HealthChecker // nil until Configure is called
	netDetector   *NetworkDetector
	behavioral    *BehavioralDetector
	docBridge     *DocBridge
	approval      *ApprovalEngine
	events        []ShadowAIEvent // In-memory event store (bounded by maxEvents)
	maxEvents     int
	socEventFn    func(source, severity, category, description string, meta map[string]string) // Bridge to SOC event bus; nil until SetSOCEventEmitter
	logger        *slog.Logger
}
// NewShadowAIController creates the main Shadow AI Control orchestrator with
// built-in vendor plugin factories registered and a detect_only fallback.
func NewShadowAIController() *ShadowAIController {
	reg := NewPluginRegistry()
	RegisterDefaultPlugins(reg)
	ctrl := &ShadowAIController{
		registry:    reg,
		fallback:    NewFallbackManager(reg, "detect_only"),
		netDetector: NewNetworkDetector(),
		behavioral:  NewBehavioralDetector(100),
		docBridge:   NewDocBridge(),
		approval:    NewApprovalEngine(),
		events:      make([]ShadowAIEvent, 0, 1000),
		maxEvents:   10000,
		logger:      slog.Default().With("component", "shadow-ai-controller"),
	}
	return ctrl
}

// SetSOCEventEmitter sets the function used to emit events into the SOC pipeline.
func (c *ShadowAIController) SetSOCEventEmitter(fn func(source, severity, category, description string, meta map[string]string)) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.socEventFn = fn
}
// Configure loads plugin configuration and initializes the integration layer.
// It rebuilds the fallback manager with the configured strategy, wires its
// detect-only events into the controller's event store, and constructs (but
// does not start — see StartHealthChecker) a health checker that escalates
// plugin status transitions to the SOC bus as HIGH integration_health events.
func (c *ShadowAIController) Configure(config *IntegrationConfig) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if err := c.registry.LoadPlugins(config); err != nil {
		return fmt.Errorf("failed to load plugins: %w", err)
	}
	c.fallback = NewFallbackManager(c.registry, config.FallbackStrategy)
	c.fallback.SetEventLogger(func(event ShadowAIEvent) {
		// recordEvent is defined elsewhere in this package.
		c.recordEvent(event)
	})
	interval := config.HealthCheckInterval
	if interval <= 0 {
		interval = 30 * time.Second // default when unset or invalid
	}
	c.healthChecker = NewHealthChecker(c.registry, interval, func(vendor string, status PluginStatus, msg string) {
		// emitSOCEvent is defined elsewhere in this package.
		c.emitSOCEvent("HIGH", "integration_health", msg, map[string]string{
			"vendor": vendor,
			"status": string(status),
		})
	})
	return nil
}

// StartHealthChecker starts continuous plugin health monitoring in a new
// goroutine. No-op when Configure has not been called yet (healthChecker is
// nil); the goroutine exits when ctx is cancelled.
func (c *ShadowAIController) StartHealthChecker(ctx context.Context) {
	if c.healthChecker != nil {
		go c.healthChecker.Start(ctx)
	}
}
// ProcessNetworkEvent analyzes a network event and enforces policy.
// Pipeline: detector match → behavioral recording → fallback block →
// in-memory event record → SOC emission. Returns nil when the event does not
// match any AI-service signature; otherwise returns the recorded event.
func (c *ShadowAIController) ProcessNetworkEvent(ctx context.Context, event NetworkEvent) *ShadowAIEvent {
	detected := c.netDetector.Analyze(event)
	if detected == nil {
		return nil
	}
	// Record behavioral data.
	c.behavioral.RecordAccess(event.User, event.Destination, event.DataSize)
	// Attempt to block.
	enforcedBy, err := c.fallback.BlockDomain(ctx, event.Destination, fmt.Sprintf("Shadow AI: %s", detected.AIService))
	if err != nil {
		// Enforcement failure is logged but not fatal — detection proceeds.
		c.logger.Error("enforcement failed", "destination", event.Destination, "error", err)
	}
	if enforcedBy != "" {
		detected.Action = "blocked"
		detected.EnforcedBy = enforcedBy
	} else {
		// Empty vendor means the fallback manager could not enforce anywhere.
		detected.Action = "detected"
	}
	detected.ID = genEventID()
	c.recordEvent(*detected)
	// Emit to SOC event bus (emitSOCEvent, recordEvent, and genEventID are
	// defined elsewhere in this package).
	c.emitSOCEvent("HIGH", "shadow_ai_usage",
		fmt.Sprintf("Shadow AI access detected: %s → %s", event.User, detected.AIService),
		map[string]string{
			"user":        event.User,
			"hostname":    event.Hostname,
			"destination": event.Destination,
			"ai_service":  detected.AIService,
			"action":      detected.Action,
			"enforced_by": detected.EnforcedBy,
		},
	)
	return detected
}
// ScanContent scans text content for AI API keys by delegating to the
// signature database's scanner (defined elsewhere in this package).
func (c *ShadowAIController) ScanContent(content string) string {
	return c.netDetector.SignatureDB().ScanForAPIKeys(content)
}

// ManualBlock manually blocks a domain, IP, or host.
// NOTE(review): accepted target types here are "domain", "ip", and "host",
// while the BlockRequest.TargetType comment documents "ip", "domain", "user" —
// confirm which set is canonical.
func (c *ShadowAIController) ManualBlock(ctx context.Context, req BlockRequest) error {
	switch req.TargetType {
	case "domain":
		// The enforcing vendor is discarded for manual blocks.
		_, err := c.fallback.BlockDomain(ctx, req.Target, req.Reason)
		return err
	case "ip":
		_, err := c.fallback.BlockIP(ctx, req.Target, req.Duration, req.Reason)
		return err
	case "host":
		_, err := c.fallback.IsolateHost(ctx, req.Target)
		return err
	default:
		return fmt.Errorf("unsupported target type: %s", req.TargetType)
	}
}
// GetStats returns aggregate shadow AI statistics for events newer than the
// cutoff implied by timeRange (via parseCutoff, defined elsewhere in this
// package): totals by action, per-service and per-department breakdowns, and
// the top 10 violators by attempt count.
func (c *ShadowAIController) GetStats(timeRange string) ShadowAIStats {
	c.mu.RLock()
	defer c.mu.RUnlock()
	cutoff := parseCutoff(timeRange)
	stats := ShadowAIStats{
		TimeRange:    timeRange,
		ByService:    make(map[string]int),
		ByDepartment: make(map[string]int),
	}
	violatorMap := make(map[string]int)
	for _, e := range c.events {
		if e.Timestamp.Before(cutoff) {
			continue
		}
		stats.Total++
		switch e.Action {
		case "blocked":
			stats.Blocked++
		case "allowed", "approved":
			stats.Approved++
		case "pending":
			stats.Pending++
		}
		if e.AIService != "" {
			stats.ByService[e.AIService]++
		}
		if dept, ok := e.Metadata["department"]; ok {
			stats.ByDepartment[dept]++
		}
		if e.UserID != "" {
			violatorMap[e.UserID]++
		}
	}
	// Build top violators list.
	for uid, count := range violatorMap {
		stats.TopViolators = append(stats.TopViolators, Violator{UserID: uid, Attempts: count})
	}
	// Sort by attempts descending (O(n log n), replacing the previous O(n²)
	// selection sort); break ties by UserID so the top-10 is deterministic
	// despite random map iteration order.
	sort.Slice(stats.TopViolators, func(i, j int) bool {
		if stats.TopViolators[i].Attempts != stats.TopViolators[j].Attempts {
			return stats.TopViolators[i].Attempts > stats.TopViolators[j].Attempts
		}
		return stats.TopViolators[i].UserID < stats.TopViolators[j].UserID
	})
	if len(stats.TopViolators) > 10 {
		stats.TopViolators = stats.TopViolators[:10]
	}
	return stats
}
// GetEvents returns up to limit recent shadow AI events (newest first).
// A limit of zero (or less) yields an empty result.
//
// Fix: a negative limit previously flowed into
// make([]ShadowAIEvent, 0, limit) and panicked; it is now clamped to 0.
func (c *ShadowAIController) GetEvents(limit int) []ShadowAIEvent {
	c.mu.RLock()
	defer c.mu.RUnlock()
	total := len(c.events)
	if total == 0 {
		return nil
	}
	if limit < 0 {
		limit = 0
	}
	start := total - limit
	if start < 0 {
		start = 0
	}
	// Walk backwards so the newest event lands first.
	result := make([]ShadowAIEvent, 0, limit)
	for i := total - 1; i >= start; i-- {
		result = append(result, c.events[i])
	}
	return result
}
// GetEvent returns a copy of a single event by ID.
// The scan runs newest-to-oldest, since recent IDs are the common lookup;
// the boolean reports whether the ID was found.
func (c *ShadowAIController) GetEvent(id string) (*ShadowAIEvent, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	for i := len(c.events); i > 0; i-- {
		// Copy before returning so callers never alias the internal slice.
		evt := c.events[i-1]
		if evt.ID == id {
			return &evt, true
		}
	}
	return nil, false
}
// IntegrationHealth returns health status of all plugins.
// Thin accessor over the plugin registry's aggregated health view.
func (c *ShadowAIController) IntegrationHealth() []PluginHealth {
	return c.registry.AllHealth()
}
// VendorHealth returns health for a specific vendor.
// The boolean reports whether the vendor is known to the registry.
func (c *ShadowAIController) VendorHealth(vendor string) (*PluginHealth, bool) {
	return c.registry.GetHealth(vendor)
}
// Registry returns the plugin registry for direct access.
// Exposed so callers can register or inspect vendor plugins without
// going through the controller.
func (c *ShadowAIController) Registry() *PluginRegistry {
	return c.registry
}
// NetworkDetector returns the network detector for configuration.
// Accessor for the controller's network-detection component.
func (c *ShadowAIController) NetworkDetector() *NetworkDetector {
	return c.netDetector
}
// BehavioralDetector returns the behavioral detector.
// Accessor for the controller's behavioral-analysis component.
func (c *ShadowAIController) BehavioralDetector() *BehavioralDetector {
	return c.behavioral
}
// DocBridge returns the document review bridge.
// Accessor for the controller's document-review component.
func (c *ShadowAIController) DocBridge() *DocBridge {
	return c.docBridge
}
// ApprovalEngine returns the approval workflow engine.
// Accessor for the controller's approval-workflow component.
func (c *ShadowAIController) ApprovalEngine() *ApprovalEngine {
	return c.approval
}
// ReviewDocument scans a document and creates an approval request if needed.
// Blocked documents never enter the approval workflow; every other outcome
// is submitted keyed on the scan's data classification. A MEDIUM SOC event
// is emitted for audit tracking in all cases.
func (c *ShadowAIController) ReviewDocument(docID, content, userID string) (*ScanResult, *ApprovalRequest) {
	result := c.docBridge.ScanDocument(docID, content, userID)

	var req *ApprovalRequest
	if result.Status != DocReviewBlocked {
		req = c.approval.SubmitRequest(userID, docID, result.DataClass)
	}

	meta := map[string]string{
		"user":       userID,
		"doc_id":     docID,
		"status":     string(result.Status),
		"data_class": string(result.DataClass),
		"pii_count":  fmt.Sprintf("%d", len(result.PIIFound)),
	}
	c.emitSOCEvent("MEDIUM", "shadow_ai_usage",
		fmt.Sprintf("Document review: %s by %s — %s (%s)",
			docID, userID, result.Status, result.DataClass),
		meta,
	)
	return result, req
}
// GenerateComplianceReport generates a compliance report for the given period.
// Interaction counts come from the event store (GetStats); PII and secret
// counts come from the document-review bridge's counters.
func (c *ShadowAIController) GenerateComplianceReport(period string) ComplianceReport {
	stats := c.GetStats(period)
	docStats := c.docBridge.Stats()

	report := ComplianceReport{
		GeneratedAt:       time.Now(),
		Period:            period,
		TotalInteractions: stats.Total,
		BlockedAttempts:   stats.Blocked,
		ApprovedReviews:   stats.Approved,
		PIIDetected:       docStats["redacted"] + docStats["blocked"],
		SecretsDetected:   docStats["blocked"],
		AuditComplete:     true,
		Regulations:       []string{"GDPR", "SOC2", "EU AI Act Article 15"},
	}
	return report
}
// --- Internal helpers ---
// recordEvent appends an event to the in-memory store, evicting the oldest
// entries once the configured capacity (maxEvents) is exceeded.
func (c *ShadowAIController) recordEvent(event ShadowAIEvent) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.events = append(c.events, event)
	if over := len(c.events) - c.maxEvents; over > 0 {
		c.events = c.events[over:]
	}
}
// emitSOCEvent forwards an event to the configured SOC callback, if any.
// The callback is snapshotted under the read lock so the call itself runs
// without holding the controller lock.
func (c *ShadowAIController) emitSOCEvent(severity, category, description string, meta map[string]string) {
	c.mu.RLock()
	emit := c.socEventFn
	c.mu.RUnlock()
	if emit == nil {
		return
	}
	emit("shadow-ai", severity, category, description, meta)
}
func parseCutoff(timeRange string) time.Time {
switch timeRange {
case "1h":
return time.Now().Add(-1 * time.Hour)
case "24h":
return time.Now().Add(-24 * time.Hour)
case "7d":
return time.Now().Add(-7 * 24 * time.Hour)
case "30d":
return time.Now().Add(-30 * 24 * time.Hour)
case "90d":
return time.Now().Add(-90 * 24 * time.Hour)
default:
return time.Now().Add(-24 * time.Hour)
}
}
// eventCounter provides a process-wide monotonic suffix for event IDs,
// guarded by eventCounterMu.
var (
	eventCounter   uint64
	eventCounterMu sync.Mutex
)

// genEventID returns a unique event ID of the form "sai-<unix-ms>-<seq>".
// The counter disambiguates IDs minted within the same millisecond.
func genEventID() string {
	eventCounterMu.Lock()
	eventCounter++
	seq := eventCounter
	eventCounterMu.Unlock()
	return fmt.Sprintf("sai-%d-%d", time.Now().UnixMilli(), seq)
}

View file

@ -0,0 +1,159 @@
package sidecar
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"time"

	domsoc "github.com/syntrex/gomcp/internal/domain/soc"
)
// BusClient sends security events to the SOC Event Bus via HTTP POST.
type BusClient struct {
	baseURL    string       // bus root URL, e.g. http://localhost:9100 (paths are appended directly)
	sensorID   string       // sensor registration ID attached to every payload
	apiKey     string       // sent as sensor_key in the ingest payload
	httpClient *http.Client // shared client; timeout + idle pooling set in NewBusClient
	maxRetries int          // extra attempts after the first send (server errors only)
}
// NewBusClient creates a client for the SOC Event Bus.
// The HTTP client carries a 10s overall timeout and a small idle-connection
// pool so repeated sends reuse TCP connections; retries default to 3.
func NewBusClient(baseURL, sensorID, apiKey string) *BusClient {
	transport := &http.Transport{
		MaxIdleConns:        10,
		IdleConnTimeout:     90 * time.Second,
		MaxIdleConnsPerHost: 5,
	}
	return &BusClient{
		baseURL:    baseURL,
		sensorID:   sensorID,
		apiKey:     apiKey,
		maxRetries: 3,
		httpClient: &http.Client{
			Timeout:   10 * time.Second,
			Transport: transport,
		},
	}
}
// ingestPayload matches the SOC ingest API expected JSON.
// Field names must stay in sync with the server-side ingest handler;
// omitempty fields are optional on the wire.
type ingestPayload struct {
	Source      string            `json:"source"`
	SensorID    string            `json:"sensor_id"`
	SensorKey   string            `json:"sensor_key,omitempty"`
	Severity    string            `json:"severity"`
	Category    string            `json:"category"`
	Subcategory string            `json:"subcategory,omitempty"`
	Confidence  float64           `json:"confidence"`
	Description string            `json:"description"`
	SessionID   string            `json:"session_id,omitempty"`
	Metadata    map[string]string `json:"metadata,omitempty"`
}
// SendEvent posts a SOCEvent to the Event Bus.
// Accepts context for graceful cancellation during retries (L-2 fix).
// Retry policy: up to maxRetries extra attempts with quadratic backoff;
// 4xx responses are treated as permanent and are not retried.
//
// Fix: the response body is now drained (io.Copy to io.Discard) before
// Close — closing an unread body prevents the transport from reusing the
// keep-alive connection across sends/retries.
func (c *BusClient) SendEvent(ctx context.Context, evt *domsoc.SOCEvent) error {
	payload := ingestPayload{
		Source:      string(evt.Source),
		SensorID:    c.sensorID,
		SensorKey:   c.apiKey,
		Severity:    string(evt.Severity),
		Category:    evt.Category,
		Subcategory: evt.Subcategory,
		Confidence:  evt.Confidence,
		Description: evt.Description,
		SessionID:   evt.SessionID,
		Metadata:    evt.Metadata,
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("sidecar: marshal event: %w", err)
	}
	url := c.baseURL + "/api/v1/soc/events"
	var lastErr error
	for attempt := 0; attempt <= c.maxRetries; attempt++ {
		if attempt > 0 {
			// Context-aware backoff: cancellable during shutdown (H-1 fix).
			backoff := time.Duration(attempt*attempt) * 500 * time.Millisecond
			select {
			case <-ctx.Done():
				return fmt.Errorf("sidecar: send cancelled during retry: %w", ctx.Err())
			case <-time.After(backoff):
			}
		}
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
		if err != nil {
			return fmt.Errorf("sidecar: create request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")
		resp, err := c.httpClient.Do(req)
		if err != nil {
			lastErr = err
			slog.Warn("sidecar: bus POST failed, retrying",
				"attempt", attempt+1, "error", err)
			continue
		}
		// Drain before closing so the connection returns to the pool.
		io.Copy(io.Discard, resp.Body)
		resp.Body.Close()
		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
			return nil
		}
		lastErr = fmt.Errorf("bus returned %d", resp.StatusCode)
		if resp.StatusCode >= 400 && resp.StatusCode < 500 {
			// Client error — don't retry.
			return lastErr
		}
		slog.Warn("sidecar: bus returned server error, retrying",
			"attempt", attempt+1, "status", resp.StatusCode)
	}
	return fmt.Errorf("sidecar: exhausted retries: %w", lastErr)
}
// Heartbeat sends a sensor heartbeat to the Event Bus.
// The payload carries only the sensor ID; any 2xx response counts as success.
//
// Fix: the response body is drained before Close so the pooled keep-alive
// connection stays reusable, and the method constant http.MethodPost is
// used instead of a raw string.
func (c *BusClient) Heartbeat() error {
	payload := map[string]string{
		"sensor_id": c.sensorID,
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("sidecar: marshal heartbeat: %w", err)
	}
	url := c.baseURL + "/api/soc/sensors/heartbeat"
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return err
	}
	// Drain + close: required for HTTP connection reuse.
	io.Copy(io.Discard, resp.Body)
	resp.Body.Close()
	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}
	return fmt.Errorf("heartbeat returned %d", resp.StatusCode)
}
// Healthy checks if the bus is reachable (M-4 fix: /healthz not /health).
// Returns true only for an exact 200 OK.
//
// Fix: the response body is drained before Close so the connection returns
// to the transport's idle pool instead of being torn down.
func (c *BusClient) Healthy() bool {
	resp, err := c.httpClient.Get(c.baseURL + "/healthz")
	if err != nil {
		return false
	}
	io.Copy(io.Discard, resp.Body)
	resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}

View file

@ -0,0 +1,214 @@
// Package sidecar implements the Universal Sidecar (§5.5) — a zero-dependency
// Go binary that runs alongside SENTINEL sensors, tails their STDOUT/logs,
// and pushes parsed security events to the SOC Event Bus.
package sidecar
import (
"log/slog"
"regexp"
"strconv"
"strings"
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
)
// Parser converts a raw log line into a SOCEvent.
// Returns nil, false if the line is not a security event.
type Parser interface {
	// Parse inspects one log line and, when it encodes a security event,
	// returns the populated event and true. Implementations trim
	// surrounding whitespace before matching.
	Parse(line string) (*domsoc.SOCEvent, bool)
}
// ── sentinel-core Parser ─────────────────────────────────────────────────────
// SentinelCoreParser parses sentinel-core detection output.
// Expected format: [DETECT] engine=<name> confidence=<float> pattern=<desc> [severity=<sev>]
type SentinelCoreParser struct{}

// coreDetectRe captures engine, confidence, pattern text, and an optional
// trailing severity override.
var coreDetectRe = regexp.MustCompile(
	`\[DETECT\]\s+engine=(\S+)\s+confidence=([0-9.]+)\s+pattern=(.+?)(?:\s+severity=(\S+))?$`)

// Parse extracts a detection event from one sentinel-core log line.
// Severity is derived from the confidence score unless the line carries an
// explicit severity=<sev> override.
func (p *SentinelCoreParser) Parse(line string) (*domsoc.SOCEvent, bool) {
	match := coreDetectRe.FindStringSubmatch(strings.TrimSpace(line))
	if match == nil {
		return nil, false
	}
	engine, rawConf, pattern, rawSev := match[1], match[2], match[3], match[4]
	conf, _ := strconv.ParseFloat(rawConf, 64)
	severity := mapConfidenceToSeverity(conf)
	if rawSev != "" {
		severity = domsoc.EventSeverity(strings.ToUpper(rawSev))
	}
	evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, severity, engine,
		engine+": "+pattern)
	evt.Confidence = conf
	evt.Subcategory = pattern
	return &evt, true
}
// ── shield Parser ────────────────────────────────────────────────────────────
// ShieldParser parses shield network block logs.
// Expected format: BLOCKED protocol=<proto> reason=<reason> source_ip=<ip>
type ShieldParser struct{}

// shieldBlockRe captures protocol, reason, and source IP from a BLOCKED line.
var shieldBlockRe = regexp.MustCompile(
	`BLOCKED\s+protocol=(\S+)\s+reason=(.+?)\s+source_ip=(\S+)`)

// Parse turns one shield BLOCKED line into a MEDIUM network_block event,
// preserving protocol/reason/source_ip in the event metadata.
func (p *ShieldParser) Parse(line string) (*domsoc.SOCEvent, bool) {
	match := shieldBlockRe.FindStringSubmatch(strings.TrimSpace(line))
	if match == nil {
		return nil, false
	}
	protocol, reason, sourceIP := match[1], match[2], match[3]
	evt := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityMedium, "network_block",
		"Shield blocked "+protocol+" from "+sourceIP+": "+reason)
	evt.Subcategory = protocol
	evt.Metadata = map[string]string{
		"source_ip": sourceIP,
		"protocol":  protocol,
		"reason":    reason,
	}
	return &evt, true
}
// ── immune Parser ────────────────────────────────────────────────────────────
// ImmuneParser parses immune system anomaly/response logs.
// Expected format: [ANOMALY] type=<type> score=<float> detail=<text>
//
// or: [RESPONSE] action=<action> target=<target> reason=<text>
type ImmuneParser struct{}

// immuneAnomalyRe captures anomaly type, score, and free-text detail.
var immuneAnomalyRe = regexp.MustCompile(
	`\[ANOMALY\]\s+type=(\S+)\s+score=([0-9.]+)\s+detail=(.+)`)

// immuneResponseRe captures response action, target, and free-text reason.
var immuneResponseRe = regexp.MustCompile(
	`\[RESPONSE\]\s+action=(\S+)\s+target=(\S+)\s+reason=(.+)`)

// Parse recognizes ANOMALY lines (severity scaled from the score) and
// RESPONSE lines (always HIGH); everything else is ignored.
func (p *ImmuneParser) Parse(line string) (*domsoc.SOCEvent, bool) {
	trimmed := strings.TrimSpace(line)
	if match := immuneAnomalyRe.FindStringSubmatch(trimmed); match != nil {
		kind, rawScore, detail := match[1], match[2], match[3]
		score, _ := strconv.ParseFloat(rawScore, 64)
		evt := domsoc.NewSOCEvent(domsoc.SourceImmune, mapConfidenceToSeverity(score),
			"anomaly", "Immune anomaly: "+kind+": "+detail)
		evt.Confidence = score
		evt.Subcategory = kind
		return &evt, true
	}
	match := immuneResponseRe.FindStringSubmatch(trimmed)
	if match == nil {
		return nil, false
	}
	action, target, reason := match[1], match[2], match[3]
	evt := domsoc.NewSOCEvent(domsoc.SourceImmune, domsoc.SeverityHigh,
		"immune_response", "Immune response: "+action+" on "+target+": "+reason)
	evt.Subcategory = action
	evt.Metadata = map[string]string{
		"action": action,
		"target": target,
		"reason": reason,
	}
	return &evt, true
}
// ── Generic Parser ───────────────────────────────────────────────────────────
// GenericParser uses a configurable regex with named groups.
// Named groups: "category", "severity", "description", "confidence".
type GenericParser struct {
	Pattern *regexp.Regexp
	Source  domsoc.EventSource
}

// NewGenericParser compiles pattern and binds it to the given event source.
// Returns the compile error verbatim on an invalid pattern.
func NewGenericParser(pattern string, source domsoc.EventSource) (*GenericParser, error) {
	compiled, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	return &GenericParser{Pattern: compiled, Source: source}, nil
}

// Parse matches one line and maps named capture groups onto event fields.
// Missing/empty groups fall back to: category "generic", description = the
// whole line, severity MEDIUM, confidence 0.5.
func (p *GenericParser) Parse(line string) (*domsoc.SOCEvent, bool) {
	match := p.Pattern.FindStringSubmatch(strings.TrimSpace(line))
	if match == nil {
		return nil, false
	}
	groups := make(map[string]string)
	for i, name := range p.Pattern.SubexpNames() {
		if i > 0 && name != "" {
			groups[name] = match[i]
		}
	}
	category := groups["category"]
	if category == "" {
		category = "generic"
	}
	description := groups["description"]
	if description == "" {
		description = line
	}
	severity := domsoc.SeverityMedium
	if s := groups["severity"]; s != "" {
		severity = domsoc.EventSeverity(strings.ToUpper(s))
	}
	confidence := 0.5
	if raw, ok := groups["confidence"]; ok {
		if parsed, err := strconv.ParseFloat(raw, 64); err == nil {
			confidence = parsed
		}
	}
	evt := domsoc.NewSOCEvent(p.Source, severity, category, description)
	evt.Confidence = confidence
	return &evt, true
}
// ── Helpers ──────────────────────────────────────────────────────────────────
// ParserForSensor returns the appropriate parser for a sensor type
// (case-insensitive). Unknown types fall back to the sentinel-core parser
// after logging a warning.
func ParserForSensor(sensorType string) Parser {
	switch strings.ToLower(sensorType) {
	case "shield":
		return &ShieldParser{}
	case "immune":
		return &ImmuneParser{}
	case "sentinel-core":
		return &SentinelCoreParser{}
	}
	slog.Warn("sidecar: unknown sensor type, using sentinel-core parser as fallback",
		"sensor_type", sensorType)
	return &SentinelCoreParser{} // fallback
}
// mapConfidenceToSeverity buckets a confidence score into severities:
// >=0.9 CRITICAL, >=0.7 HIGH, >=0.5 MEDIUM, >=0.3 LOW, otherwise INFO.
func mapConfidenceToSeverity(conf float64) domsoc.EventSeverity {
	if conf >= 0.9 {
		return domsoc.SeverityCritical
	}
	if conf >= 0.7 {
		return domsoc.SeverityHigh
	}
	if conf >= 0.5 {
		return domsoc.SeverityMedium
	}
	if conf >= 0.3 {
		return domsoc.SeverityLow
	}
	return domsoc.SeverityInfo
}

View file

@ -0,0 +1,157 @@
package sidecar
import (
"context"
"fmt"
"io"
"log/slog"
"sync/atomic"
"time"
)
// Config holds sidecar runtime configuration.
// All fields are set once at startup and read by Run.
type Config struct {
	SensorType   string        // sentinel-core, shield, immune, generic
	LogPath      string        // Path to sensor log file, or "stdin"/"-" for standard input
	BusURL       string        // SOC Event Bus URL (e.g., http://localhost:9100)
	SensorID     string        // Sensor registration ID
	APIKey       string        // Sensor API key
	PollInterval time.Duration // Log file poll interval (0 means the Tailer default)
}
// Stats tracks sidecar runtime metrics (thread-safe via atomic).
type Stats struct {
	LinesRead  atomic.Int64 // total lines received from the line source
	EventsSent atomic.Int64 // lines parsed as events and delivered to the bus
	Errors     atomic.Int64 // failed bus sends
	StartedAt  time.Time    // set once at construction; read-only afterwards
}
// StatsSnapshot is a non-atomic copy for reading/logging.
// Produced by Sidecar.GetStats; safe to marshal to JSON.
type StatsSnapshot struct {
	LinesRead  int64     `json:"lines_read"`
	EventsSent int64     `json:"events_sent"`
	Errors     int64     `json:"errors"`
	StartedAt  time.Time `json:"started_at"`
}
// Sidecar is the main orchestrator: tailer → parser → bus client.
type Sidecar struct {
	config Config     // set at construction (New); read by Run/processLines
	parser Parser     // chosen from config.SensorType via ParserForSensor
	client *BusClient // delivers parsed events to the SOC Event Bus
	tailer *Tailer    // line source (file, stdin, or arbitrary reader)
	stats  Stats      // atomic runtime counters
}
// New creates a Sidecar with the given config.
// The parser is selected from cfg.SensorType (unknown types fall back to
// sentinel-core, see ParserForSensor) and the start time is stamped now.
func New(cfg Config) *Sidecar {
	s := &Sidecar{
		config: cfg,
		stats:  Stats{StartedAt: time.Now()},
	}
	s.parser = ParserForSensor(cfg.SensorType)
	s.client = NewBusClient(cfg.BusURL, cfg.SensorID, cfg.APIKey)
	s.tailer = NewTailer(cfg.PollInterval)
	return s
}
// Run starts the sidecar pipeline: tail → parse → send.
// Blocks until ctx is cancelled.
func (s *Sidecar) Run(ctx context.Context) error {
	slog.Info("sidecar: starting",
		"sensor_type", s.config.SensorType,
		"log_path", s.config.LogPath,
		"bus_url", s.config.BusURL,
		"sensor_id", s.config.SensorID,
	)

	// Choose the line source: stdin ("stdin" or "-") or a tailed file.
	var (
		lines <-chan string
		err   error
	)
	switch s.config.LogPath {
	case "stdin", "-":
		lines = s.tailer.FollowStdin(ctx)
	default:
		lines, err = s.tailer.FollowFile(ctx, s.config.LogPath)
		if err != nil {
			return fmt.Errorf("sidecar: open log: %w", err)
		}
	}

	// Periodic liveness signal to the bus; stops with ctx.
	go s.heartbeatLoop(ctx)

	// Main pipeline loop (shared with RunReader).
	return s.processLines(ctx, lines)
}
// RunReader runs the sidecar from any io.Reader (for testing).
// It uses the same parse→send loop as Run but skips heartbeats.
func (s *Sidecar) RunReader(ctx context.Context, r io.Reader) error {
	return s.processLines(ctx, s.tailer.FollowReader(ctx, r))
}
// processLines is the shared pipeline loop: parse → send → stats.
// Extracted to DRY between Run() and RunReader() (H-3 fix).
// Returns nil both on context cancellation and when the line channel closes.
func (s *Sidecar) processLines(ctx context.Context, lines <-chan string) error {
	for {
		select {
		case line, open := <-lines:
			if !open {
				slog.Info("sidecar: input closed")
				return nil
			}
			s.stats.LinesRead.Add(1)
			evt, isEvent := s.parser.Parse(line)
			if !isEvent {
				continue // Not a security event.
			}
			evt.SensorID = s.config.SensorID
			if err := s.client.SendEvent(ctx, evt); err != nil {
				s.stats.Errors.Add(1)
				slog.Error("sidecar: send failed",
					"error", err,
					"category", evt.Category,
				)
				continue
			}
			s.stats.EventsSent.Add(1)
		case <-ctx.Done():
			slog.Info("sidecar: shutting down",
				"lines_read", s.stats.LinesRead.Load(),
				"events_sent", s.stats.EventsSent.Load(),
				"errors", s.stats.Errors.Load(),
			)
			return nil
		}
	}
}
// GetStats returns a snapshot of current runtime metrics (thread-safe).
// Each counter is loaded atomically; the snapshot itself is a plain value.
func (s *Sidecar) GetStats() StatsSnapshot {
	var snap StatsSnapshot
	snap.LinesRead = s.stats.LinesRead.Load()
	snap.EventsSent = s.stats.EventsSent.Load()
	snap.Errors = s.stats.Errors.Load()
	snap.StartedAt = s.stats.StartedAt
	return snap
}
// heartbeatLoop posts a heartbeat every 30 seconds until ctx is cancelled.
// Failures are logged and otherwise ignored so a bus restart does not kill
// the sidecar.
func (s *Sidecar) heartbeatLoop(ctx context.Context) {
	const period = 30 * time.Second
	ticker := time.NewTicker(period)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if err := s.client.Heartbeat(); err != nil {
				slog.Warn("sidecar: heartbeat failed", "error", err)
			}
		case <-ctx.Done():
			return
		}
	}
}

View file

@ -0,0 +1,306 @@
package sidecar
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"
)
// ── Parser Tests ─────────────────────────────────────────────────────────────
// TestSentinelCoreParser exercises the [DETECT] line format, including the
// optional severity override and non-matching/empty lines.
func TestSentinelCoreParser(t *testing.T) {
	parser := &SentinelCoreParser{}
	cases := []struct {
		line     string
		wantOK   bool
		category string
		confMin  float64
	}{
		{"[DETECT] engine=jailbreak confidence=0.95 pattern=DAN prompt", true, "jailbreak", 0.9},
		{"[DETECT] engine=injection confidence=0.6 pattern=ignore_previous", true, "injection", 0.5},
		{"[DETECT] engine=exfiltration confidence=0.3 pattern=tool_call severity=HIGH", true, "exfiltration", 0.2},
		{"INFO: Engine loaded successfully", false, "", 0},
		{"", false, "", 0},
	}
	for _, tc := range cases {
		evt, ok := parser.Parse(tc.line)
		if ok != tc.wantOK {
			t.Errorf("Parse(%q) ok=%v, want %v", tc.line, ok, tc.wantOK)
			continue
		}
		if !ok {
			continue
		}
		if evt.Category != tc.category {
			t.Errorf("Parse(%q) category=%q, want %q", tc.line, evt.Category, tc.category)
		}
		if evt.Confidence < tc.confMin {
			t.Errorf("Parse(%q) confidence=%.2f, want >=%.2f", tc.line, evt.Confidence, tc.confMin)
		}
	}
}
// TestShieldParser checks BLOCKED-line extraction and that non-BLOCKED
// lines are rejected.
func TestShieldParser(t *testing.T) {
	parser := &ShieldParser{}
	cases := []struct {
		line   string
		wantOK bool
		proto  string
		ip     string
	}{
		{"BLOCKED protocol=tcp reason=port_scan source_ip=192.168.1.100", true, "tcp", "192.168.1.100"},
		{"BLOCKED protocol=udp reason=dns_exfil source_ip=10.0.0.5", true, "udp", "10.0.0.5"},
		{"ALLOWED protocol=https from 1.2.3.4", false, "", ""},
		{"", false, "", ""},
	}
	for _, tc := range cases {
		evt, ok := parser.Parse(tc.line)
		if ok != tc.wantOK {
			t.Errorf("Parse(%q) ok=%v, want %v", tc.line, ok, tc.wantOK)
			continue
		}
		if !ok {
			continue
		}
		if evt.Metadata["protocol"] != tc.proto {
			t.Errorf("protocol=%q, want %q", evt.Metadata["protocol"], tc.proto)
		}
		if evt.Metadata["source_ip"] != tc.ip {
			t.Errorf("source_ip=%q, want %q", evt.Metadata["source_ip"], tc.ip)
		}
	}
}
// TestImmuneParser covers both ANOMALY and RESPONSE line formats plus a
// non-matching informational line.
func TestImmuneParser(t *testing.T) {
	parser := &ImmuneParser{}
	cases := []struct {
		line     string
		wantOK   bool
		category string
	}{
		{"[ANOMALY] type=drift score=0.85 detail=behavior shift detected", true, "anomaly"},
		{"[RESPONSE] action=quarantine target=session-123 reason=high risk", true, "immune_response"},
		{"[INFO] system healthy", false, ""},
	}
	for _, tc := range cases {
		evt, ok := parser.Parse(tc.line)
		if ok != tc.wantOK {
			t.Errorf("Parse(%q) ok=%v, want %v", tc.line, ok, tc.wantOK)
			continue
		}
		if !ok {
			continue
		}
		if evt.Category != tc.category {
			t.Errorf("category=%q, want %q", evt.Category, tc.category)
		}
	}
}
// TestGenericParser checks named-group extraction of category, severity,
// and description from a custom pattern.
func TestGenericParser(t *testing.T) {
	parser, err := NewGenericParser(
		`ALERT\s+(?P<category>\S+)\s+(?P<severity>\S+)\s+(?P<description>.+)`,
		"external",
	)
	if err != nil {
		t.Fatalf("NewGenericParser: %v", err)
	}
	evt, ok := parser.Parse("ALERT injection HIGH suspicious sql in query string")
	if !ok {
		t.Fatal("expected match")
	}
	if evt.Category != "injection" {
		t.Errorf("category=%q, want injection", evt.Category)
	}
	if string(evt.Severity) != "HIGH" {
		t.Errorf("severity=%q, want HIGH", evt.Severity)
	}
}
// TestParserForSensor verifies the sensor-type → parser mapping, including
// the sentinel-core fallback for unknown types.
func TestParserForSensor(t *testing.T) {
	cases := map[string]string{
		"sentinel-core": "*sidecar.SentinelCoreParser",
		"shield":        "*sidecar.ShieldParser",
		"immune":        "*sidecar.ImmuneParser",
		"unknown":       "*sidecar.SentinelCoreParser", // fallback
	}
	for sensorType, wantType := range cases {
		parser := ParserForSensor(sensorType)
		if parser == nil {
			t.Errorf("ParserForSensor(%q) returned nil", sensorType)
			continue
		}
		if gotType := fmt.Sprintf("%T", parser); gotType != wantType {
			t.Errorf("ParserForSensor(%q) = %s, want %s", sensorType, gotType, wantType)
		}
	}
}
// ── Tailer Tests ─────────────────────────────────────────────────────────────
// TestTailer_FollowReader verifies that every non-empty line is emitted in
// order and that the channel closes at EOF.
func TestTailer_FollowReader(t *testing.T) {
	input := "[DETECT] engine=jailbreak confidence=0.95 pattern=DAN\nINFO: done\n[DETECT] engine=exfil confidence=0.7 pattern=tool_call\n"
	tailer := NewTailer(50 * time.Millisecond)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	var got []string
	for line := range tailer.FollowReader(ctx, strings.NewReader(input)) {
		got = append(got, line)
	}
	if len(got) != 3 {
		t.Fatalf("expected 3 lines, got %d: %v", len(got), got)
	}
	if got[0] != "[DETECT] engine=jailbreak confidence=0.95 pattern=DAN" {
		t.Errorf("line[0]=%q", got[0])
	}
}
// ── BusClient Tests ──────────────────────────────────────────────────────────
// TestBusClient_SendEvent posts one parsed event and asserts the payload the
// bus receives.
//
// Fix: the handler runs on the test server's goroutine, so the captured
// payload slice is now mutex-guarded — the previous unsynchronized append
// was a data race under `go test -race`.
func TestBusClient_SendEvent(t *testing.T) {
	var (
		mu       sync.Mutex
		received []map[string]any
	)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/v1/soc/events" {
			var payload map[string]any
			json.NewDecoder(r.Body).Decode(&payload)
			mu.Lock()
			received = append(received, payload)
			mu.Unlock()
			w.WriteHeader(http.StatusCreated)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()
	client := NewBusClient(ts.URL, "test-sensor", "test-key")
	p := &SentinelCoreParser{}
	evt, ok := p.Parse("[DETECT] engine=jailbreak confidence=0.95 pattern=DAN")
	if !ok {
		t.Fatal("parse failed")
	}
	err := client.SendEvent(context.Background(), evt)
	if err != nil {
		t.Fatalf("SendEvent: %v", err)
	}
	// SendEvent is synchronous, but take the lock anyway for -race cleanliness.
	mu.Lock()
	defer mu.Unlock()
	if len(received) != 1 {
		t.Fatalf("expected 1 received event, got %d", len(received))
	}
	if received[0]["source"] != "sentinel-core" {
		t.Errorf("source=%v, want sentinel-core", received[0]["source"])
	}
	if received[0]["category"] != "jailbreak" {
		t.Errorf("category=%v, want jailbreak", received[0]["category"])
	}
	if received[0]["sensor_id"] != "test-sensor" {
		t.Errorf("sensor_id=%v, want test-sensor", received[0]["sensor_id"])
	}
}
// TestBusClient_Healthy checks both a responsive server and a connection
// failure against a closed port.
func TestBusClient_Healthy(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	if healthy := NewBusClient(ts.URL, "s1", "k1").Healthy(); !healthy {
		t.Error("expected healthy")
	}
	// Unreachable server.
	if healthy := NewBusClient("http://localhost:1", "s2", "k2").Healthy(); healthy {
		t.Error("expected unhealthy")
	}
}
// ── E2E Pipeline Test ────────────────────────────────────────────────────────
// TestSidecar_E2E_Pipeline runs the full reader-based pipeline against a
// fake bus and checks line/event accounting plus the first payload.
//
// Fixes: receivedEvents is appended from the test server's handler
// goroutine, so it is now mutex-guarded (data race under -race); the fake
// health route is "/healthz" to match BusClient.Healthy (the old "/health"
// case was dead code).
func TestSidecar_E2E_Pipeline(t *testing.T) {
	var (
		mu             sync.Mutex
		receivedEvents []map[string]any
	)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v1/soc/events":
			var payload map[string]any
			json.NewDecoder(r.Body).Decode(&payload)
			mu.Lock()
			receivedEvents = append(receivedEvents, payload)
			mu.Unlock()
			w.WriteHeader(http.StatusCreated)
		case "/healthz":
			w.WriteHeader(http.StatusOK)
		default:
			w.WriteHeader(http.StatusOK)
		}
	}))
	defer ts.Close()
	input := strings.Join([]string{
		"[DETECT] engine=jailbreak confidence=0.95 pattern=DAN",
		"INFO: processing complete",
		"[DETECT] engine=injection confidence=0.7 pattern=ignore_previous",
		"DEBUG: internal state update",
		"[DETECT] engine=exfiltration confidence=0.5 pattern=tool_call",
	}, "\n")
	cfg := Config{
		SensorType: "sentinel-core",
		LogPath:    "stdin",
		BusURL:     ts.URL,
		SensorID:   "e2e-test-sensor",
		APIKey:     "test-key",
	}
	sc := New(cfg)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := sc.RunReader(ctx, strings.NewReader(input)); err != nil {
		t.Fatalf("RunReader: %v", err)
	}
	stats := sc.GetStats()
	if stats.LinesRead != 5 {
		t.Errorf("LinesRead=%d, want 5", stats.LinesRead)
	}
	if stats.EventsSent != 3 {
		t.Errorf("EventsSent=%d, want 3 (3 DETECT lines, 2 skipped)", stats.EventsSent)
	}
	mu.Lock()
	defer mu.Unlock()
	if len(receivedEvents) != 3 {
		t.Fatalf("received %d events, want 3", len(receivedEvents))
	}
	// Verify first event.
	first := receivedEvents[0]
	if first["category"] != "jailbreak" {
		t.Errorf("first event category=%v, want jailbreak", first["category"])
	}
	if first["sensor_id"] != "e2e-test-sensor" {
		t.Errorf("first event sensor_id=%v, want e2e-test-sensor", first["sensor_id"])
	}
}

View file

@ -0,0 +1,162 @@
package sidecar
import (
"bufio"
"context"
"io"
"log/slog"
"os"
"time"
)
// Tailer follows a log file or stdin, emitting lines via a channel.
type Tailer struct {
pollInterval time.Duration
}
// NewTailer creates a Tailer with the given poll interval for file changes.
func NewTailer(pollInterval time.Duration) *Tailer {
if pollInterval <= 0 {
pollInterval = 200 * time.Millisecond
}
return &Tailer{pollInterval: pollInterval}
}
// FollowFile tails a file, seeking to end on start.
// Sends lines on the returned channel until ctx is cancelled.
// Rotation is detected by inode comparison (fileRotated); after rotation
// the new file is read from its beginning.
//
// NOTE(review): bufio.Scanner treats EOF as a line terminator, so a line the
// writer has only partially flushed may be emitted in two pieces across
// polls — confirm sensor logs are written line-atomically before relying on
// exact framing.
func (t *Tailer) FollowFile(ctx context.Context, path string) (<-chan string, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// Seek to end — only process new lines.
	if _, err := f.Seek(0, io.SeekEnd); err != nil {
		f.Close()
		return nil, err
	}
	ch := make(chan string, 256)
	go func() {
		// The deferred Close targets whatever file f points at when the
		// goroutine exits (the rotated-in file after a reopen).
		defer f.Close()
		defer close(ch)
		// H-2 fix: Use Scanner with 1MB max line size to prevent OOM.
		const maxLineSize = 1 << 20 // 1MB
		scanner := bufio.NewScanner(f)
		scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
		for {
			// Non-blocking cancellation check between reads.
			select {
			case <-ctx.Done():
				return
			default:
			}
			if scanner.Scan() {
				line := scanner.Text()
				if line != "" {
					// Blocking send, abortable on shutdown.
					select {
					case ch <- line:
					case <-ctx.Done():
						return
					}
				}
				continue
			}
			// Scanner stopped — either EOF or error.
			if err := scanner.Err(); err != nil {
				slog.Error("sidecar: read error", "error", err)
				return
			}
			// EOF — wait and check for rotation.
			time.Sleep(t.pollInterval)
			if t.fileRotated(f, path) {
				slog.Info("sidecar: log rotated, reopening", "path", path)
				f.Close()
				newF, err := os.Open(path)
				if err != nil {
					slog.Error("sidecar: reopen failed", "path", path, "error", err)
					return
				}
				f = newF
				scanner = bufio.NewScanner(f)
				scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
			} else {
				// Same file, re-create scanner at current position.
				scanner = bufio.NewScanner(f)
				scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
			}
		}
	}()
	return ch, nil
}
// FollowStdin reads from stdin line by line.
// Empty lines are dropped; the channel closes on EOF, read error, or
// context cancellation.
func (t *Tailer) FollowStdin(ctx context.Context) <-chan string {
	out := make(chan string, 256)
	go func() {
		defer close(out)
		sc := bufio.NewScanner(os.Stdin)
		for sc.Scan() {
			// Stop promptly if the context was cancelled between reads.
			select {
			case <-ctx.Done():
				return
			default:
			}
			if text := sc.Text(); text != "" {
				select {
				case out <- text:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
	return out
}
// FollowReader reads from any io.Reader (for testing).
// Semantics match FollowStdin: empty lines dropped, channel closed at EOF
// or cancellation.
func (t *Tailer) FollowReader(ctx context.Context, r io.Reader) <-chan string {
	out := make(chan string, 256)
	go func() {
		defer close(out)
		sc := bufio.NewScanner(r)
		for sc.Scan() {
			// Stop promptly if the context was cancelled between reads.
			select {
			case <-ctx.Done():
				return
			default:
			}
			if text := sc.Text(); text != "" {
				select {
				case out <- text:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
	return out
}
// fileRotated checks if the file path now points to a different inode.
// Stat errors on either side are conservatively treated as "not rotated".
func (t *Tailer) fileRotated(current *os.File, path string) bool {
	openInfo, openErr := current.Stat()
	pathInfo, pathErr := os.Stat(path)
	if openErr != nil || pathErr != nil {
		return false
	}
	return !os.SameFile(openInfo, pathInfo)
}

View file

@ -0,0 +1,527 @@
package soc
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
"github.com/syntrex/gomcp/internal/infrastructure/audit"
"github.com/syntrex/gomcp/internal/infrastructure/sqlite"
)
// newTestServiceWithLogger creates a SOC service backed by in-memory SQLite WITH a decision logger.
// The construction order (db → repo → logger) is fixed; cleanup runs in
// reverse via a single t.Cleanup callback.
func newTestServiceWithLogger(t *testing.T) *Service {
	t.Helper()
	db, err := sqlite.OpenMemory()
	require.NoError(t, err)
	repo, err := sqlite.NewSOCRepo(db)
	require.NoError(t, err)
	logger, err := audit.NewDecisionLogger(t.TempDir())
	require.NoError(t, err)
	// Close logger BEFORE TempDir cleanup (Windows file locking).
	t.Cleanup(func() {
		logger.Close()
		db.Close()
	})
	return NewService(repo, logger)
}
// --- E2E: Full Pipeline (Ingest → Correlation → Incident → Playbook) ---
// TestE2E_FullPipeline_IngestToIncident drives ingest → correlation →
// incident creation → persistence → decision-chain verification.
func TestE2E_FullPipeline_IngestToIncident(t *testing.T) {
	svc := newTestServiceWithLogger(t)

	// A lone jailbreak event must not correlate into an incident.
	first := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "detected jailbreak attempt")
	first.SensorID = "sensor-e2e-1"
	firstID, firstInc, err := svc.IngestEvent(first)
	require.NoError(t, err)
	assert.NotEmpty(t, firstID)
	assert.Nil(t, firstInc, "single event should not trigger correlation")

	// A tool_abuse event from the same sensor completes rule SOC-CR-001.
	second := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "tool abuse detected")
	second.SensorID = "sensor-e2e-1"
	secondID, incident, err := svc.IngestEvent(second)
	require.NoError(t, err)
	assert.NotEmpty(t, secondID)
	require.NotNil(t, incident, "jailbreak + tool_abuse should create an incident")
	assert.Equal(t, domsoc.SeverityCritical, incident.Severity)
	assert.Equal(t, "Multi-stage Jailbreak", incident.Title)
	assert.NotEmpty(t, incident.ID)
	assert.NotEmpty(t, incident.Events, "incident should reference triggering events")

	// The incident must be readable back from the repository.
	stored, err := svc.GetIncident(incident.ID)
	require.NoError(t, err)
	assert.Equal(t, incident.ID, stored.ID)

	// The audit decision chain must still verify.
	dash, err := svc.Dashboard()
	require.NoError(t, err)
	assert.True(t, dash.ChainValid, "decision chain should be valid")
	assert.Greater(t, dash.TotalEvents, 0)
}
// TestE2E_TemporalSequenceCorrelation feeds the ordered auth_bypass →
// tool_abuse pair that sequence rule SOC-CR-010 expects.
func TestE2E_TemporalSequenceCorrelation(t *testing.T) {
	svc := newTestServiceWithLogger(t)

	bypass := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "auth_bypass", "brute force detected")
	bypass.SensorID = "sensor-seq-1"
	_, _, err := svc.IngestEvent(bypass)
	require.NoError(t, err)

	abuse := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "tool_abuse", "tool escalation")
	abuse.SensorID = "sensor-seq-1"
	_, incident, err := svc.IngestEvent(abuse)
	require.NoError(t, err)

	// Should trigger either SOC-CR-010 (sequence) or another matching rule.
	if incident != nil {
		assert.NotEmpty(t, incident.KillChainPhase)
		assert.NotEmpty(t, incident.MITREMapping)
	}
}
// --- E2E: Sensor Authentication Flow ---
// TestE2E_SensorAuth_FullFlow verifies the sensor authentication matrix:
// a valid key is accepted, while a wrong key, a missing SensorID, and an
// unknown sensor are all rejected with descriptive errors.
func TestE2E_SensorAuth_FullFlow(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// Configure sensor keys.
	svc.SetSensorKeys(map[string]string{
		"sensor-auth-1": "secret-key-1",
		"sensor-auth-2": "secret-key-2",
	})

	// Valid auth — should succeed.
	valid := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "auth test")
	valid.SensorID = "sensor-auth-1"
	valid.SensorKey = "secret-key-1"
	id, _, err := svc.IngestEvent(valid)
	require.NoError(t, err)
	assert.NotEmpty(t, id)

	// Every remaining case must be rejected; errSubstr is matched against the error text.
	rejected := []struct {
		detail    string
		sensorID  string
		sensorKey string
		errSubstr string
	}{
		{"bad key", "sensor-auth-1", "wrong-key", "auth"},        // wrong key for a known sensor
		{"no sensor id", "", "", "sensor_id required"},           // missing SensorID (S-1 fix)
		{"unknown sensor", "sensor-unknown", "whatever", "auth"}, // sensor not in the key map
	}
	for _, tc := range rejected {
		ev := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", tc.detail)
		ev.SensorID = tc.sensorID
		ev.SensorKey = tc.sensorKey
		_, _, err := svc.IngestEvent(ev)
		require.Error(t, err)
		assert.Contains(t, err.Error(), tc.errSubstr)
	}
}
// --- E2E: Drain Mode ---
// TestE2E_DrainMode_RejectsNewEvents checks the drain lifecycle: ingest works
// before Drain(), is rejected with a "draining" error while draining, and
// works again after Resume().
func TestE2E_DrainMode_RejectsNewEvents(t *testing.T) {
	svc := newTestServiceWithLogger(t)

	// ingest sends one low-severity event from the drain-test sensor.
	ingest := func(detail string) error {
		ev := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", detail)
		ev.SensorID = "sensor-drain"
		_, _, err := svc.IngestEvent(ev)
		return err
	}

	// Ingest works before drain.
	require.NoError(t, ingest("pre-drain"))

	// Activate drain mode.
	svc.Drain()
	assert.True(t, svc.IsDraining())

	// New events should be rejected.
	err := ingest("during-drain")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "draining")

	// Resume.
	svc.Resume()
	assert.False(t, svc.IsDraining())

	// Events should work again.
	require.NoError(t, ingest("post-drain"))
}
// --- E2E: Webhook Delivery ---
// TestE2E_WebhookFiredOnIncident verifies that creating an incident via the
// jailbreak + tool_abuse correlation causes the configured webhook endpoint
// to be called at least once.
func TestE2E_WebhookFiredOnIncident(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// Set up a test webhook server that records each request path.
	var mu sync.Mutex
	var received []string
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		received = append(received, r.URL.Path)
		mu.Unlock()
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()
	svc.SetWebhookConfig(WebhookConfig{
		Endpoints:  []string{ts.URL + "/webhook"},
		MaxRetries: 1,
		TimeoutSec: 5,
	})
	// Trigger an incident via correlation.
	evt1 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "jailbreak e2e")
	evt1.SensorID = "sensor-wh"
	// FIX: the first ingest's error was silently dropped — a failed ingest here
	// would make the correlation (and the whole webhook assertion) vacuous.
	_, _, err := svc.IngestEvent(evt1)
	require.NoError(t, err)
	evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "tool abuse e2e")
	evt2.SensorID = "sensor-wh"
	_, inc, err := svc.IngestEvent(evt2)
	require.NoError(t, err)
	if inc != nil {
		// Give the async webhook goroutine time to fire.
		time.Sleep(200 * time.Millisecond)
		mu.Lock()
		assert.GreaterOrEqual(t, len(received), 1, "webhook should have been called")
		mu.Unlock()
	}
}
// --- E2E: Verdict Flow ---
// TestE2E_VerdictFlow walks an incident through the verdict lifecycle:
// OPEN → INVESTIGATING → RESOLVED, re-reading the incident after each update.
func TestE2E_VerdictFlow(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// Create an incident via correlation.
	evt1 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "verdict test 1")
	evt1.SensorID = "sensor-vd"
	// FIX: ingest errors were silently discarded; assert them so a broken
	// pipeline fails loudly instead of being masked by the inc == nil skip.
	_, _, err := svc.IngestEvent(evt1)
	require.NoError(t, err)
	evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "verdict test 2")
	evt2.SensorID = "sensor-vd"
	_, inc, err := svc.IngestEvent(evt2)
	require.NoError(t, err)
	if inc == nil {
		t.Skip("no incident created — correlation rules may not match with current sliding window state")
	}
	// Verify initial status is OPEN.
	got, err := svc.GetIncident(inc.ID)
	require.NoError(t, err)
	assert.Equal(t, domsoc.StatusOpen, got.Status)
	// Update to INVESTIGATING.
	err = svc.UpdateVerdict(inc.ID, domsoc.StatusInvestigating)
	require.NoError(t, err)
	got, err = svc.GetIncident(inc.ID) // FIX: error was ignored
	require.NoError(t, err)
	assert.Equal(t, domsoc.StatusInvestigating, got.Status)
	// Update to RESOLVED.
	err = svc.UpdateVerdict(inc.ID, domsoc.StatusResolved)
	require.NoError(t, err)
	got, err = svc.GetIncident(inc.ID) // FIX: error was ignored
	require.NoError(t, err)
	assert.Equal(t, domsoc.StatusResolved, got.Status)
}
// --- E2E: Analytics Report ---
// TestE2E_AnalyticsReport ingests one high-severity event per category and
// checks that the 24-hour analytics report aggregates categories, sources,
// and the events-per-hour rate.
func TestE2E_AnalyticsReport(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// Ingest several events.
	categories := []string{"jailbreak", "injection", "exfiltration", "auth_bypass", "tool_abuse"}
	for i, cat := range categories {
		evt := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, cat, fmt.Sprintf("analytics test %d", i))
		evt.SensorID = "sensor-analytics"
		// FIX: ingest errors were silently dropped — a failing pipeline would
		// previously surface only as confusing empty-report assertions below.
		_, _, err := svc.IngestEvent(evt)
		require.NoError(t, err)
	}
	report, err := svc.Analytics(24)
	require.NoError(t, err)
	assert.NotNil(t, report)
	assert.Greater(t, len(report.TopCategories), 0)
	assert.Greater(t, len(report.TopSources), 0)
	assert.GreaterOrEqual(t, report.EventsPerHour, float64(0))
}
// --- E2E: Multi-Sensor Concurrent Ingest ---
// TestE2E_ConcurrentIngest hammers the service with 10 goroutines × 10 events
// each to check for races, panics, and data corruption. Rate-limited ingests
// are tolerated, but are now logged instead of being silently discarded.
func TestE2E_ConcurrentIngest(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	var wg sync.WaitGroup
	// FIX: renamed from "errors" (which shadowed the stdlib package name) and
	// the collected errors are now surfaced via t.Logf instead of being dropped.
	var ingestErrs []error
	var mu sync.Mutex
	// 10 sensors × 10 events each = 100 concurrent ingests.
	for s := 0; s < 10; s++ {
		wg.Add(1)
		go func(sensorNum int) {
			defer wg.Done()
			for i := 0; i < 10; i++ {
				evt := domsoc.NewSOCEvent(
					domsoc.SourceSentinelCore,
					domsoc.SeverityLow,
					"test",
					fmt.Sprintf("concurrent sensor-%d event-%d", sensorNum, i),
				)
				evt.SensorID = fmt.Sprintf("sensor-conc-%d", sensorNum)
				_, _, err := svc.IngestEvent(evt)
				if err != nil {
					mu.Lock()
					ingestErrs = append(ingestErrs, err)
					mu.Unlock()
				}
			}
		}(s)
	}
	wg.Wait()
	// Some events may be rate-limited (100 events/sec per sensor),
	// but there should be no panics or data corruption.
	if len(ingestErrs) > 0 {
		t.Logf("ingest errors (tolerated under rate limiting): %d, first: %v", len(ingestErrs), ingestErrs[0])
	}
	dash, err := svc.Dashboard()
	require.NoError(t, err)
	assert.Greater(t, dash.TotalEvents, 0, "at least some events should have been ingested")
}
// --- E2E: Lattice TSA Chain Violation (SOC-CR-012) ---
// TestE2E_TSAChainViolation drives the full lattice TSA kill chain
// (auth_bypass → tool_abuse → exfiltration) from a single sensor and expects
// a critical, MITRE-mapped incident that is persisted and retrievable.
func TestE2E_TSAChainViolation(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// SOC-CR-012 requires: auth_bypass → tool_abuse → exfiltration within 15 min.
	chain := []struct {
		category string
		severity domsoc.EventSeverity
	}{
		{"auth_bypass", domsoc.SeverityHigh},
		{"tool_abuse", domsoc.SeverityHigh},
		{"exfiltration", domsoc.SeverityCritical},
	}
	var incident *domsoc.Incident
	for _, step := range chain {
		ev := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, step.severity, step.category, "TSA chain test: "+step.category)
		ev.SensorID = "sensor-tsa"
		_, inc, err := svc.IngestEvent(ev)
		require.NoError(t, err)
		if inc != nil {
			incident = inc
		}
	}
	// The TSA chain (auth_bypass + tool_abuse + exfiltration) should trigger
	// SOC-CR-012 or another matching rule.
	require.NotNil(t, incident, "TSA chain (auth_bypass → tool_abuse → exfiltration) should create an incident")
	assert.Equal(t, domsoc.SeverityCritical, incident.Severity)
	assert.NotEmpty(t, incident.MITREMapping)
	// Verify incident is persisted.
	stored, err := svc.GetIncident(incident.ID)
	require.NoError(t, err)
	assert.Equal(t, incident.ID, stored.ID)
}
// --- E2E: Zero-G Mode Excludes Playbook Auto-Response ---
// TestE2E_ZeroGExcludedFromAutoResponse verifies that events flagged with
// ZeroGMode still flow through correlation (an incident may be created) but
// do NOT trigger playbook auto-response — observed here as zero webhook calls.
func TestE2E_ZeroGExcludedFromAutoResponse(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// Set up a test webhook server to track playbook webhook notifications.
	var mu sync.Mutex
	var webhookCalls int
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		webhookCalls++
		mu.Unlock()
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()
	svc.SetWebhookConfig(WebhookConfig{
		Endpoints:  []string{ts.URL + "/webhook"},
		MaxRetries: 1,
		TimeoutSec: 5,
	})
	// Ingest jailbreak + tool_abuse with ZeroGMode=true.
	// This should trigger correlation (incident created) but NOT playbooks.
	evt1 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityHigh, "jailbreak", "zero-g jailbreak test")
	evt1.SensorID = "sensor-zg"
	evt1.ZeroGMode = true
	_, _, err := svc.IngestEvent(evt1)
	require.NoError(t, err)
	evt2 := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "tool_abuse", "zero-g tool abuse test")
	evt2.SensorID = "sensor-zg"
	evt2.ZeroGMode = true
	_, inc, err := svc.IngestEvent(evt2)
	require.NoError(t, err)
	// Correlation should still run — incident should be created.
	if inc != nil {
		assert.Equal(t, domsoc.SeverityCritical, inc.Severity)
		// Wait for any async webhook goroutines.
		time.Sleep(200 * time.Millisecond)
		// Webhook should NOT have been called (playbook skipped for Zero-G).
		mu.Lock()
		assert.Equal(t, 0, webhookCalls, "webhooks should NOT fire for Zero-G events — playbook must be skipped")
		mu.Unlock()
	}
	// Verify decision log records the PLAYBOOK_SKIPPED:ZERO_G entry.
	// NOTE(review): only chain integrity and entry count are asserted here —
	// the actual PLAYBOOK_SKIPPED:ZERO_G entry text is not checked; confirm
	// whether it should be.
	logPath := svc.DecisionLogPath()
	if logPath != "" {
		valid, broken, err := audit.VerifyChainFromFile(logPath)
		require.NoError(t, err)
		assert.Equal(t, 0, broken, "decision chain should be intact")
		assert.Greater(t, valid, 0, "should have decision entries")
	}
}
// --- E2E: Decision Logger Tamper Detection ---
// TestE2E_DecisionLoggerTampering builds a hash-chained decision log by
// ingesting events, then corrupts one mid-chain line on disk and checks that
// chain verification pinpoints the break.
func TestE2E_DecisionLoggerTampering(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// Ingest several events to build up a decision chain.
	for i := 0; i < 10; i++ {
		evt := domsoc.NewSOCEvent(
			domsoc.SourceSentinelCore,
			domsoc.SeverityLow,
			"test",
			fmt.Sprintf("tamper test event %d", i),
		)
		evt.SensorID = "sensor-tamper"
		_, _, err := svc.IngestEvent(evt)
		require.NoError(t, err)
	}
	// Step 1: Verify chain is valid.
	logPath := svc.DecisionLogPath()
	require.NotEmpty(t, logPath, "decision log path should be set")
	validCount, brokenLine, err := audit.VerifyChainFromFile(logPath)
	require.NoError(t, err)
	assert.Equal(t, 0, brokenLine, "chain should be intact before tampering")
	assert.GreaterOrEqual(t, validCount, 10, "should have at least 10 decision entries")
	// Step 2: Tamper with the log file — modify a line mid-chain.
	data, err := os.ReadFile(logPath)
	require.NoError(t, err)
	lines := bytes.Split(data, []byte("\n"))
	// Guarded so the tamper step degrades to a no-op (rather than panicking on
	// a short slice) if the log unexpectedly has few lines; step 1 already
	// required at least 10 entries.
	if len(lines) > 5 {
		// Corrupt line 5 by altering content.
		lines[4] = []byte("TAMPERED|2026-01-01T00:00:00Z|SOC|FAKE|fake_reason|0000000000")
		err = os.WriteFile(logPath, bytes.Join(lines, []byte("\n")), 0644)
		require.NoError(t, err)
		// Step 3: Verify chain detects the tamper.
		// VerifyChainFromFile reports the break via the broken-line index, not
		// via its error return, so err2 is still expected to be nil here.
		_, brokenLine2, err2 := audit.VerifyChainFromFile(logPath)
		require.NoError(t, err2)
		assert.Greater(t, brokenLine2, 0, "chain should detect tampering — broken line reported")
	}
}
// --- E2E: Cross-Sensor Session Correlation (SOC-CR-011) ---
// TestE2E_CrossSensorSessionCorrelation simulates one attack session observed
// by three different sensors/sources and expects rule SOC-CR-011 (3+ sensors,
// same session_id) — or another matching rule — to raise an incident.
func TestE2E_CrossSensorSessionCorrelation(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// SOC-CR-011 requires 3+ events from different sensors with same session_id.
	const sessionID = "session-xsensor-e2e-001"
	observations := []struct {
		source   domsoc.EventSource
		sensor   string
		category string
	}{
		{domsoc.SourceShield, "sensor-shield-1", "auth_bypass"},
		{domsoc.SourceSentinelCore, "sensor-core-1", "jailbreak"},
		{domsoc.SourceImmune, "sensor-immune-1", "exfiltration"},
	}
	var incident *domsoc.Incident
	for _, obs := range observations {
		ev := domsoc.NewSOCEvent(obs.source, domsoc.SeverityHigh, obs.category, "cross-sensor test: "+obs.category)
		ev.SensorID = obs.sensor
		ev.SessionID = sessionID
		_, inc, err := svc.IngestEvent(ev)
		require.NoError(t, err)
		if inc != nil {
			incident = inc
		}
	}
	// After 3 events from different sensors/sources with same session_id,
	// at least one correlation rule should have matched.
	require.NotNil(t, incident, "cross-sensor session attack (3 sources, same session_id) should create incident")
	assert.NotEmpty(t, incident.ID)
	assert.NotEmpty(t, incident.Events, "incident should reference triggering events")
}
// --- E2E: Crescendo Escalation (SOC-CR-015) ---
// TestE2E_CrescendoEscalation feeds three jailbreak events with strictly
// ascending severity (LOW→MEDIUM→HIGH) from one sensor and expects the
// crescendo rule SOC-CR-015 to raise a critical incident mapped to T1059.
func TestE2E_CrescendoEscalation(t *testing.T) {
	svc := newTestServiceWithLogger(t)
	// SOC-CR-015: 3+ jailbreak events with ascending severity within 15 min.
	ladder := []domsoc.EventSeverity{
		domsoc.SeverityLow,
		domsoc.SeverityMedium,
		domsoc.SeverityHigh,
	}
	var incident *domsoc.Incident
	for step, sev := range ladder {
		ev := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, sev, "jailbreak",
			fmt.Sprintf("crescendo jailbreak attempt %d", step+1))
		ev.SensorID = "sensor-crescendo"
		_, inc, err := svc.IngestEvent(ev)
		require.NoError(t, err)
		if inc != nil {
			incident = inc
		}
	}
	// The ascending severity pattern (LOW→MEDIUM→HIGH) should trigger SOC-CR-015.
	require.NotNil(t, incident, "crescendo pattern (LOW→MEDIUM→HIGH jailbreaks) should create incident")
	assert.Equal(t, domsoc.SeverityCritical, incident.Severity)
	assert.Contains(t, incident.MITREMapping, "T1059")
}

View file

@ -0,0 +1,100 @@
package soc
import (
"fmt"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
"github.com/syntrex/gomcp/internal/infrastructure/audit"
"github.com/syntrex/gomcp/internal/infrastructure/sqlite"
)
// newBenchService creates a minimal SOC service for benchmarking, backed by a
// temporary SQLite database and decision logger that are torn down via
// b.Cleanup. Rate limiting is disabled so benchmarks measure raw pipeline
// throughput rather than the limiter.
func newBenchService(b *testing.B) *Service {
	b.Helper()
	dir := b.TempDir()
	db, err := sqlite.Open(dir + "/bench.db")
	require.NoError(b, err)
	b.Cleanup(func() { db.Close() })
	repo, err := sqlite.NewSOCRepo(db)
	require.NoError(b, err)
	decisionLog, err := audit.NewDecisionLogger(dir)
	require.NoError(b, err)
	b.Cleanup(func() { decisionLog.Close() })
	svc := NewService(repo, decisionLog)
	svc.DisableRateLimit() // benchmarks measure throughput, not rate limiting
	return svc
}
// BenchmarkIngestEvent measures single-event pipeline throughput.
// Each iteration builds a fresh event with a unique ID, so event construction
// and formatting are part of the measured cost.
func BenchmarkIngestEvent(b *testing.B) {
	svc := newBenchService(b)
	b.ResetTimer() // exclude service/DB setup from the timed region
	for i := 0; i < b.N; i++ {
		event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityMedium, "injection",
			fmt.Sprintf("Bench event #%d", i))
		event.ID = fmt.Sprintf("bench-evt-%d", i)
		_, _, err := svc.IngestEvent(event)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkIngestEvent_WithCorrelation measures pipeline with correlation active.
// Pre-loads 50 jailbreak events (outside the timed region) so correlation
// rules have matching candidates during the benchmark loop.
func BenchmarkIngestEvent_WithCorrelation(b *testing.B) {
	svc := newBenchService(b)
	// Pre-load events to make correlation rules meaningful.
	for i := 0; i < 50; i++ {
		event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "jailbreak",
			fmt.Sprintf("Pre-load jailbreak #%d", i))
		event.ID = fmt.Sprintf("preload-%d", i)
		// FIX: the pre-load ingest error was silently dropped; a failed
		// pre-load invalidates the benchmark, so fail fast instead.
		if _, _, err := svc.IngestEvent(event); err != nil {
			b.Fatalf("pre-load ingest failed: %v", err)
		}
		// Presumably spaces out event timestamps for time-window rules — TODO confirm.
		time.Sleep(time.Microsecond)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityHigh, "jailbreak",
			fmt.Sprintf("Corr bench event #%d", i))
		event.ID = fmt.Sprintf("bench-corr-%d", i)
		_, _, err := svc.IngestEvent(event)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkIngestEvent_Parallel measures concurrent ingest throughput.
// An atomic counter hands each parallel goroutine a unique event number so
// generated event IDs never collide across workers.
func BenchmarkIngestEvent_Parallel(b *testing.B) {
	svc := newBenchService(b)
	var counter atomic.Int64
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			n := counter.Add(1)
			event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityLow, "jailbreak",
				fmt.Sprintf("Parallel bench #%d", n))
			event.ID = fmt.Sprintf("bench-par-%d", n)
			_, _, err := svc.IngestEvent(event)
			if err != nil {
				b.Fatal(err)
			}
		}
	})
}

View file

@ -0,0 +1,153 @@
package soc
import (
"fmt"
"math"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
domsoc "github.com/syntrex/gomcp/internal/domain/soc"
"github.com/syntrex/gomcp/internal/infrastructure/audit"
"github.com/syntrex/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/require"
)
// TestLoadTest_SustainedThroughput measures SOC pipeline throughput and latency
// under sustained concurrent load. Reports p50/p95/p99 latencies and events/sec.
// Skipped under -short; asserts <5% error rate and >100 events/sec.
func TestLoadTest_SustainedThroughput(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping load test in short mode")
	}
	// Setup service with file-based SQLite for concurrency safety.
	tmpDir := t.TempDir()
	db, err := sqlite.Open(tmpDir + "/loadtest.db")
	require.NoError(t, err)
	repo, err := sqlite.NewSOCRepo(db)
	require.NoError(t, err)
	logger, err := audit.NewDecisionLogger(tmpDir)
	require.NoError(t, err)
	t.Cleanup(func() {
		// Close the logger before the DB it logs alongside.
		logger.Close()
		db.Close()
	})
	svc := NewService(repo, logger)
	svc.DisableRateLimit() // bypass rate limiter for raw throughput
	// Load test parameters.
	const (
		numWorkers   = 16
		eventsPerWkr = 200
		totalEvents  = numWorkers * eventsPerWkr
	)
	categories := []string{"jailbreak", "injection", "exfiltration", "auth_bypass", "tool_abuse"}
	sources := []domsoc.EventSource{domsoc.SourceSentinelCore, domsoc.SourceShield, domsoc.SourceGoMCP}
	var (
		wg sync.WaitGroup
		// One pre-allocated slot per event; workers write disjoint index
		// ranges, so no lock is needed for latencies.
		latencies = make([]time.Duration, totalEvents)
		errors    int64 // updated via atomic.AddInt64
		incidents int64 // updated via atomic.AddInt64
	)
	start := time.Now()
	for w := 0; w < numWorkers; w++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			for i := 0; i < eventsPerWkr; i++ {
				// Global event index; also rotates source/category per event.
				idx := workerID*eventsPerWkr + i
				evt := domsoc.NewSOCEvent(
					sources[idx%len(sources)],
					domsoc.SeverityHigh,
					categories[idx%len(categories)],
					fmt.Sprintf("load-test w%d-e%d", workerID, i),
				)
				evt.SensorID = fmt.Sprintf("load-sensor-%d", workerID)
				t0 := time.Now()
				_, inc, err := svc.IngestEvent(evt)
				latencies[idx] = time.Since(t0)
				if err != nil {
					atomic.AddInt64(&errors, 1)
				}
				if inc != nil {
					atomic.AddInt64(&incidents, 1)
				}
			}
		}(w)
	}
	wg.Wait()
	totalDuration := time.Since(start)
	// Compute latency percentiles (sort ascending, then nearest-rank).
	sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] })
	p50 := percentile(latencies, 50)
	p95 := percentile(latencies, 95)
	p99 := percentile(latencies, 99)
	mean := meanDuration(latencies)
	eventsPerSec := float64(totalEvents) / totalDuration.Seconds()
	// Report results.
	t.Logf("═══════════════════════════════════════════════")
	t.Logf(" SENTINEL SOC Load Test Results")
	t.Logf("═══════════════════════════════════════════════")
	t.Logf(" Workers: %d", numWorkers)
	t.Logf(" Events/worker: %d", eventsPerWkr)
	t.Logf(" Total events: %d", totalEvents)
	t.Logf(" Duration: %s", totalDuration.Round(time.Millisecond))
	t.Logf(" Throughput: %.0f events/sec", eventsPerSec)
	t.Logf("───────────────────────────────────────────────")
	t.Logf(" Mean: %s", mean.Round(time.Microsecond))
	t.Logf(" P50: %s", p50.Round(time.Microsecond))
	t.Logf(" P95: %s", p95.Round(time.Microsecond))
	t.Logf(" P99: %s", p99.Round(time.Microsecond))
	t.Logf(" Min: %s", latencies[0].Round(time.Microsecond))
	t.Logf(" Max: %s", latencies[len(latencies)-1].Round(time.Microsecond))
	t.Logf("───────────────────────────────────────────────")
	t.Logf(" Errors: %d (%.1f%%)", errors, float64(errors)/float64(totalEvents)*100)
	t.Logf(" Incidents: %d", incidents)
	t.Logf("═══════════════════════════════════════════════")
	// Assertions: basic sanity checks.
	require.Less(t, float64(errors)/float64(totalEvents), 0.05, "error rate should be < 5%%")
	require.Greater(t, eventsPerSec, float64(100), "should sustain > 100 events/sec")
}
func percentile(sorted []time.Duration, p int) time.Duration {
if len(sorted) == 0 {
return 0
}
idx := int(math.Ceil(float64(p)/100.0*float64(len(sorted)))) - 1
if idx < 0 {
idx = 0
}
if idx >= len(sorted) {
idx = len(sorted) - 1
}
return sorted[idx]
}
func meanDuration(ds []time.Duration) time.Duration {
if len(ds) == 0 {
return 0
}
var total time.Duration
for _, d := range ds {
total += d
}
return total / time.Duration(len(ds))
}

File diff suppressed because it is too large Load diff

View file

@ -143,7 +143,7 @@ func TestRunPlaybook_IncidentNotFound(t *testing.T) {
svc := newTestService(t)
// Use a valid playbook ID from defaults.
_, err := svc.RunPlaybook("pb-auto-block-jailbreak", "nonexistent-inc")
_, err := svc.RunPlaybook("pb-block-jailbreak", "nonexistent-inc")
require.Error(t, err)
assert.Contains(t, err.Error(), "incident not found")
}

View file

@ -0,0 +1,255 @@
package soc
import (
"encoding/json"
"log/slog"
"net/http"
"strings"
"time"
)
// STIXBundle represents a STIX 2.1 bundle (simplified).
// Only the fields needed for IOC import are modeled; unknown JSON fields are
// ignored on decode.
type STIXBundle struct {
	Type    string       `json:"type"` // "bundle"
	ID      string       `json:"id"`
	Objects []STIXObject `json:"objects"` // objects of any STIX type; only indicators are consumed
}
// STIXObject represents a generic STIX 2.1 object.
// It is a simplified superset: fields not relevant to a given object type are
// simply left at their zero values.
type STIXObject struct {
	Type        string    `json:"type"` // indicator, malware, attack-pattern, etc.
	ID          string    `json:"id"`
	Created     time.Time `json:"created"`
	Modified    time.Time `json:"modified"`
	Name        string    `json:"name,omitempty"`
	Description string    `json:"description,omitempty"`
	Pattern     string    `json:"pattern,omitempty"`      // STIX pattern (indicators)
	PatternType string    `json:"pattern_type,omitempty"` // stix, pcre, sigma
	ValidFrom   time.Time `json:"valid_from,omitempty"`
	Labels      []string  `json:"labels,omitempty"` // e.g. "malicious-activity"; used to derive IOC severity
	// Kill chain phases for attack-pattern objects.
	KillChainPhases []struct {
		KillChainName string `json:"kill_chain_name"`
		PhaseName     string `json:"phase_name"`
	} `json:"kill_chain_phases,omitempty"`
	// External references (CVE, etc.)
	ExternalReferences []struct {
		SourceName  string `json:"source_name"`
		ExternalID  string `json:"external_id,omitempty"`
		URL         string `json:"url,omitempty"`
		Description string `json:"description,omitempty"`
	} `json:"external_references,omitempty"`
}
// STIXFeedConfig configures automatic STIX feed polling.
type STIXFeedConfig struct {
	Name    string            `json:"name"`    // Feed name (e.g., "OTX", "MISP")
	URL     string            `json:"url"`     // TAXII or HTTP feed URL
	APIKey  string            `json:"api_key"` // Authentication key (sent as X-OTX-API-KEY when non-empty)
	Headers map[string]string `json:"headers"` // Additional headers; applied last, so they override defaults
	Interval time.Duration    `json:"interval"` // Poll interval (default: 1h when zero)
	Enabled  bool              `json:"enabled"`  // Disabled feeds are skipped by FeedSync.Start
}
// FeedSync syncs IOCs from STIX/TAXII feeds into the ThreatIntelStore.
// Construct with NewFeedSync; background polling starts via Start.
type FeedSync struct {
	feeds  []STIXFeedConfig  // configured feeds; only Enabled ones are polled
	store  *ThreatIntelStore // destination for imported IOCs
	client *http.Client      // shared HTTP client (30s timeout, set in NewFeedSync)
}
// NewFeedSync creates a feed synchronizer targeting store, with a shared
// 30-second-timeout HTTP client for all feed fetches.
func NewFeedSync(store *ThreatIntelStore, feeds []STIXFeedConfig) *FeedSync {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	return &FeedSync{
		store:  store,
		feeds:  feeds,
		client: httpClient,
	}
}
// Start begins polling all enabled feeds in the background; disabled feeds
// are skipped. Each polling goroutine exits when done is closed.
func (f *FeedSync) Start(done <-chan struct{}) {
	for _, feed := range f.feeds {
		if feed.Enabled {
			go f.pollFeed(feed, done)
		}
	}
}
// pollFeed fetches a single STIX feed immediately, then on every tick of the
// feed's interval (defaulting to one hour), until done is closed.
func (f *FeedSync) pollFeed(feed STIXFeedConfig, done <-chan struct{}) {
	every := feed.Interval
	if every == 0 {
		every = time.Hour // default poll interval
	}
	slog.Info("stix feed started", "feed", feed.Name, "url", feed.URL, "interval", every)

	// Initial fetch before the first tick.
	f.fetchFeed(feed)

	tick := time.NewTicker(every)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			f.fetchFeed(feed)
		case <-done:
			slog.Info("stix feed stopped", "feed", feed.Name)
			return
		}
	}
}
// fetchFeed performs a single HTTP GET and processes the STIX bundle.
// All failures are logged and swallowed so one bad poll cannot stop the
// pollFeed loop; there is no retry within a poll cycle.
func (f *FeedSync) fetchFeed(feed STIXFeedConfig) {
	req, err := http.NewRequest(http.MethodGet, feed.URL, nil)
	if err != nil {
		slog.Error("stix feed: request error", "feed", feed.Name, "error", err)
		return
	}
	req.Header.Set("Accept", "application/stix+json;version=2.1")
	// NOTE(review): the API key is always sent as X-OTX-API-KEY, which is
	// OTX-specific — confirm this is intended for non-OTX feeds (MISP etc.).
	if feed.APIKey != "" {
		req.Header.Set("X-OTX-API-KEY", feed.APIKey)
	}
	// Custom headers are applied last, so they override the defaults above.
	for k, v := range feed.Headers {
		req.Header.Set(k, v)
	}
	resp, err := f.client.Do(req)
	if err != nil {
		slog.Error("stix feed: fetch error", "feed", feed.Name, "error", err)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		slog.Warn("stix feed: non-200 response", "feed", feed.Name, "status", resp.StatusCode)
		return
	}
	// NOTE(review): the body is decoded without a size cap; for untrusted
	// feeds consider wrapping resp.Body in io.LimitReader.
	var bundle STIXBundle
	if err := json.NewDecoder(resp.Body).Decode(&bundle); err != nil {
		slog.Error("stix feed: decode error", "feed", feed.Name, "error", err)
		return
	}
	imported := f.processBundle(feed.Name, bundle)
	slog.Info("stix feed synced",
		"feed", feed.Name,
		"objects", len(bundle.Objects),
		"iocs_imported", imported,
	)
}
// processBundle extracts IOCs from the bundle's STIX indicator objects,
// tags them with the feed name and the object's labels, and adds them to the
// store. Returns the number of IOCs imported.
func (f *FeedSync) processBundle(feedName string, bundle STIXBundle) int {
	count := 0
	for _, obj := range bundle.Objects {
		// Only indicator objects carry a convertible pattern.
		if obj.Type != "indicator" {
			continue
		}
		if obj.Pattern == "" {
			continue
		}
		ioc := stixPatternToIOC(obj)
		if ioc == nil {
			continue
		}
		ioc.Source = feedName
		ioc.Tags = obj.Labels
		f.store.AddIOC(*ioc)
		count++
	}
	return count
}
// stixPatternToIOC converts a STIX indicator pattern to our IOC format.
// Supports: [file:hashes.'SHA-256' = '...'], [ipv4-addr:value = '...'],
// [ipv6-addr:value = '...'], [domain-name:value = '...'], [url:value = '...'].
// Returns nil for unsupported pattern types or when no value can be extracted.
func stixPatternToIOC(obj STIXObject) *IOC {
	pattern := obj.Pattern
	// Prefer Modified as the sighting time; fall back to Created when zero.
	now := obj.Modified
	if now.IsZero() {
		now = obj.Created
	}
	ioc := &IOC{
		Value:      "",
		Severity:   "medium", // default when no recognized label is present
		FirstSeen:  now,
		LastSeen:   now,
		Confidence: 0.7, // fixed feed confidence; not derived from the object
	}
	switch {
	case strings.Contains(pattern, "file:hashes"):
		ioc.Type = IOCTypeHash
	case strings.Contains(pattern, "ipv4-addr:value"):
		ioc.Type = IOCTypeIP
	// FIX: ipv6-addr indicators were silently dropped; treat them as IP IOCs
	// as well (backward-compatible — previously this case returned nil).
	case strings.Contains(pattern, "ipv6-addr:value"):
		ioc.Type = IOCTypeIP
	case strings.Contains(pattern, "domain-name:value"):
		ioc.Type = IOCTypeDomain
	case strings.Contains(pattern, "url:value"):
		ioc.Type = IOCTypeURL
	default:
		return nil
	}
	ioc.Value = extractSTIXValue(pattern)
	if ioc.Value == "" {
		return nil
	}
	// Derive severity from STIX labels. NOTE(review): when several labels
	// match, the last one in Labels wins — confirm this ordering is intended.
	for _, label := range obj.Labels {
		switch {
		case strings.Contains(label, "anomalous-activity"):
			ioc.Severity = "low"
		case strings.Contains(label, "malicious-activity"):
			ioc.Severity = "critical"
		case strings.Contains(label, "attribution"):
			ioc.Severity = "high"
		}
	}
	return ioc
}
// extractSTIXValue pulls the single-quoted value from a STIX pattern such as:
//
//	[ipv4-addr:value = '192.168.1.1']
//	[file:hashes.'SHA-256' = 'e3b0c44...']
//
// Returns "" when the pattern has no "= '…'" section or the quote is unterminated.
func extractSTIXValue(pattern string) string {
	// Anchor on "= '" so quotes earlier in the pattern (e.g. hashes.'SHA-256')
	// are skipped; Cut splits at the first occurrence.
	_, rest, found := strings.Cut(pattern, "= '")
	if !found {
		return ""
	}
	value, _, closed := strings.Cut(rest, "'")
	if !closed {
		return ""
	}
	return value
}
// DefaultOTXFeed returns a pre-configured AlienVault OTX feed config with an
// hourly poll interval. The feed is enabled only when an API key is supplied.
func DefaultOTXFeed(apiKey string) STIXFeedConfig {
	cfg := STIXFeedConfig{
		Name:     "AlienVault OTX",
		URL:      "https://otx.alienvault.com/api/v1/pulses/subscribed",
		APIKey:   apiKey,
		Interval: time.Hour,
		Headers:  map[string]string{"X-OTX-API-KEY": apiKey},
	}
	cfg.Enabled = apiKey != ""
	return cfg
}
// IOC type is defined in threat_intel.go — this file uses it directly.

View file

@ -0,0 +1,137 @@
package soc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- stixPatternToIOC ---
// TestSTIXPatternToIOC_IPv4 checks that an ipv4-addr indicator maps to an IP
// IOC with the default "medium" severity and a non-zero FirstSeen.
func TestSTIXPatternToIOC_IPv4(t *testing.T) {
	indicator := STIXObject{
		Type:     "indicator",
		Pattern:  "[ipv4-addr:value = '192.168.1.1']",
		Modified: time.Now(),
	}
	got := stixPatternToIOC(indicator)
	require.NotNil(t, got, "should parse IPv4 pattern")
	assert.Equal(t, IOCTypeIP, got.Type)
	assert.Equal(t, "192.168.1.1", got.Value)
	assert.Equal(t, "medium", got.Severity)
	assert.False(t, got.FirstSeen.IsZero(), "FirstSeen must be set")
}
// TestSTIXPatternToIOC_Hash checks SHA-256 file-hash parsing and that the
// "malicious-activity" label escalates severity to critical.
func TestSTIXPatternToIOC_Hash(t *testing.T) {
	sum := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
	indicator := STIXObject{
		Type:     "indicator",
		Pattern:  "[file:hashes.'SHA-256' = '" + sum + "']",
		Modified: time.Now(),
		Labels:   []string{"malicious-activity"},
	}
	got := stixPatternToIOC(indicator)
	require.NotNil(t, got, "should parse hash pattern")
	assert.Equal(t, IOCTypeHash, got.Type)
	assert.Equal(t, sum, got.Value)
	assert.Equal(t, "critical", got.Severity, "malicious-activity label → critical")
}
// TestSTIXPatternToIOC_Domain checks domain-name parsing and the
// "attribution" label → high severity mapping.
func TestSTIXPatternToIOC_Domain(t *testing.T) {
	indicator := STIXObject{
		Type:     "indicator",
		Pattern:  "[domain-name:value = 'evil.example.com']",
		Modified: time.Now(),
		Labels:   []string{"attribution"},
	}
	got := stixPatternToIOC(indicator)
	require.NotNil(t, got)
	assert.Equal(t, IOCTypeDomain, got.Type)
	assert.Equal(t, "evil.example.com", got.Value)
	assert.Equal(t, "high", got.Severity, "attribution label → high")
}
// TestSTIXPatternToIOC_Unsupported ensures indicators with a pattern type the
// converter does not recognize (email-addr here) yield nil.
func TestSTIXPatternToIOC_Unsupported(t *testing.T) {
	got := stixPatternToIOC(STIXObject{
		Type:     "indicator",
		Pattern:  "[email-addr:value = 'attacker@evil.com']",
		Modified: time.Now(),
	})
	assert.Nil(t, got, "unsupported pattern type should return nil")
}
// TestSTIXPatternToIOC_FallbackToCreated ensures the sighting timestamp falls
// back to Created when Modified is the zero time.
func TestSTIXPatternToIOC_FallbackToCreated(t *testing.T) {
	created := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
	indicator := STIXObject{
		Type:    "indicator",
		Pattern: "[ipv4-addr:value = '10.0.0.1']",
		Created: created,
		// Modified is zero → should fall back to Created
	}
	got := stixPatternToIOC(indicator)
	require.NotNil(t, got)
	assert.Equal(t, created, got.FirstSeen, "should fall back to Created when Modified is zero")
}
// --- extractSTIXValue ---
// TestExtractSTIXValue table-tests the quoted-value extractor, including
// degenerate inputs that must yield the empty string.
func TestExtractSTIXValue(t *testing.T) {
	cases := []struct {
		name    string
		pattern string
		want    string
	}{
		{"ipv4", "[ipv4-addr:value = '1.2.3.4']", "1.2.3.4"},
		{"domain", "[domain-name:value = 'evil.com']", "evil.com"},
		{"hash", "[file:hashes.'SHA-256' = 'abc123']", "abc123"},
		{"empty_no_quotes", "[ipv4-addr:value = ]", ""},
		{"single_quote_only", "'", ""},
		{"empty_string", "", ""},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, extractSTIXValue(tc.pattern))
		})
	}
}
// --- processBundle ---
// TestProcessBundle_FiltersNonIndicators verifies that only indicator objects
// with a non-empty pattern are imported; other STIX types and empty patterns
// are skipped.
func TestProcessBundle_FiltersNonIndicators(t *testing.T) {
	store := NewThreatIntelStore()
	feedSync := NewFeedSync(store, nil)
	now := time.Now()
	bundle := STIXBundle{
		Type: "bundle",
		ID:   "bundle--test",
		Objects: []STIXObject{
			{Type: "indicator", Pattern: "[ipv4-addr:value = '10.0.0.1']", Modified: now},
			{Type: "malware", Name: "BadMalware"},   // should be skipped
			{Type: "indicator", Pattern: ""},        // empty pattern → skipped
			{Type: "attack-pattern", Name: "Phish"}, // should be skipped
			{Type: "indicator", Pattern: "[domain-name:value = 'bad.com']", Modified: now},
		},
	}
	got := feedSync.processBundle("test-feed", bundle)
	assert.Equal(t, 2, got, "should import only 2 valid indicators")
	assert.Equal(t, 2, store.TotalIOCs, "store should have 2 IOCs")
}
// --- DefaultOTXFeed ---
// TestDefaultOTXFeed checks the OTX preset: populated fields and the
// Enabled flag driven by presence of the API key.
func TestDefaultOTXFeed(t *testing.T) {
	withKey := DefaultOTXFeed("test-key-123")
	assert.Equal(t, "AlienVault OTX", withKey.Name)
	assert.True(t, withKey.Enabled, "should be enabled when key provided")
	assert.Contains(t, withKey.URL, "otx.alienvault.com")
	assert.Equal(t, time.Hour, withKey.Interval)
	assert.Equal(t, "test-key-123", withKey.Headers["X-OTX-API-KEY"])

	withoutKey := DefaultOTXFeed("")
	assert.False(t, withoutKey.Enabled, "should be disabled when key is empty")
}

View file

@ -6,8 +6,8 @@ import (
"bytes"
"encoding/json"
"fmt"
"log"
"math/rand"
"log/slog"
"math/rand/v2"
"net/http"
"sync"
"time"
@ -58,9 +58,9 @@ type WebhookNotifier struct {
client *http.Client
enabled bool
// Stats
Sent int64 `json:"sent"`
Failed int64 `json:"failed"`
// Stats (unexported — access via Stats() method)
sent int64
failed int64
}
// NewWebhookNotifier creates a notifier with the given config.
@ -80,23 +80,7 @@ func NewWebhookNotifier(config WebhookConfig) *WebhookNotifier {
}
}
// severityRank returns numeric rank for severity comparison.
func severityRank(s domsoc.EventSeverity) int {
switch s {
case domsoc.SeverityCritical:
return 5
case domsoc.SeverityHigh:
return 4
case domsoc.SeverityMedium:
return 3
case domsoc.SeverityLow:
return 2
case domsoc.SeverityInfo:
return 1
default:
return 0
}
}
// NotifyIncident sends an incident webhook to all configured endpoints.
// Non-blocking: fires goroutines for each endpoint.
@ -105,9 +89,9 @@ func (w *WebhookNotifier) NotifyIncident(eventType string, incident *domsoc.Inci
return nil
}
// Severity filter
// Severity filter — use domain Rank() method (Q-1 FIX: removed duplicate severityRank).
if w.config.MinSeverity != "" {
if severityRank(incident.Severity) < severityRank(w.config.MinSeverity) {
if incident.Severity.Rank() < w.config.MinSeverity.Rank() {
return nil
}
}
@ -146,9 +130,9 @@ func (w *WebhookNotifier) NotifyIncident(eventType string, incident *domsoc.Inci
w.mu.Lock()
for _, r := range results {
if r.Success {
w.Sent++
w.sent++
} else {
w.Failed++
w.failed++
}
}
w.mu.Unlock()
@ -213,7 +197,7 @@ func (w *WebhookNotifier) sendWithRetry(url string, body []byte) WebhookResult {
result.Error = err.Error()
if attempt < w.config.MaxRetries {
backoff := time.Duration(1<<uint(attempt)) * 500 * time.Millisecond
jitter := time.Duration(rand.Intn(500)) * time.Millisecond
jitter := time.Duration(rand.IntN(500)) * time.Millisecond
time.Sleep(backoff + jitter)
continue
}
@ -230,12 +214,12 @@ func (w *WebhookNotifier) sendWithRetry(url string, body []byte) WebhookResult {
result.Error = fmt.Sprintf("HTTP %d", resp.StatusCode)
if attempt < w.config.MaxRetries {
backoff := time.Duration(1<<uint(attempt)) * 500 * time.Millisecond
jitter := time.Duration(rand.Intn(500)) * time.Millisecond
jitter := time.Duration(rand.IntN(500)) * time.Millisecond
time.Sleep(backoff + jitter)
}
}
log.Printf("[SOC] webhook failed after %d retries: %s → %s", w.config.MaxRetries, url, result.Error)
slog.Error("webhook failed", "retries", w.config.MaxRetries, "url", url, "error", result.Error)
return result
}
@ -243,5 +227,5 @@ func (w *WebhookNotifier) sendWithRetry(url string, body []byte) WebhookResult {
func (w *WebhookNotifier) Stats() (sent, failed int64) {
w.mu.RLock()
defer w.mu.RUnlock()
return w.Sent, w.Failed
return w.sent, w.failed
}