package shadow_ai import ( "context" "fmt" "strings" "sync" "testing" "time" ) // --- Mock Plugins --- type mockFirewall struct { blockIPs []string blockDomains []string healthy bool mu sync.Mutex } func newMockFirewall(healthy bool) *mockFirewall { return &mockFirewall{healthy: healthy} } func (m *mockFirewall) BlockIP(_ context.Context, ip string, _ time.Duration, _ string) error { m.mu.Lock() defer m.mu.Unlock() m.blockIPs = append(m.blockIPs, ip) return nil } func (m *mockFirewall) BlockDomain(_ context.Context, domain string, _ string) error { m.mu.Lock() defer m.mu.Unlock() m.blockDomains = append(m.blockDomains, domain) return nil } func (m *mockFirewall) UnblockIP(_ context.Context, _ string) error { return nil } func (m *mockFirewall) UnblockDomain(_ context.Context, _ string) error { return nil } func (m *mockFirewall) HealthCheck(_ context.Context) error { if !m.healthy { return fmt.Errorf("firewall offline") } return nil } func (m *mockFirewall) Vendor() string { return "mock-firewall" } type mockEDR struct { isolated []string healthy bool } func newMockEDR(healthy bool) *mockEDR { return &mockEDR{healthy: healthy} } func (m *mockEDR) IsolateHost(_ context.Context, hostname string) error { m.isolated = append(m.isolated, hostname) return nil } func (m *mockEDR) ReleaseHost(_ context.Context, _ string) error { return nil } func (m *mockEDR) KillProcess(_ context.Context, _ string, _ int) error { return nil } func (m *mockEDR) QuarantineFile(_ context.Context, _ string, _ string) error { return nil } func (m *mockEDR) HealthCheck(_ context.Context) error { if !m.healthy { return fmt.Errorf("EDR offline") } return nil } func (m *mockEDR) Vendor() string { return "mock-edr" } type mockGateway struct { blockedURLs []string healthy bool } func newMockGateway(healthy bool) *mockGateway { return &mockGateway{healthy: healthy} } func (m *mockGateway) BlockURL(_ context.Context, url string, _ string) error { m.blockedURLs = append(m.blockedURLs, url) return nil } func (m *mockGateway) UnblockURL(_ context.Context, _ string) error { return nil } func (m *mockGateway) BlockCategory(_ context.Context, _ string) error { return nil } func (m *mockGateway) HealthCheck(_ context.Context) error { if !m.healthy { return fmt.Errorf("gateway offline") } return nil } func (m *mockGateway) Vendor() string { return "mock-gateway" } // --- Registry Tests --- func TestRegistry_RegisterAndGet(t *testing.T) { reg := NewPluginRegistry() fw := newMockFirewall(true) reg.RegisterFactory(PluginTypeFirewall, "mock-firewall", func() interface{} { return fw }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "mock-firewall", Enabled: true}, }, } if err := reg.LoadPlugins(cfg); err != nil { t.Fatalf("LoadPlugins: %v", err) } if reg.PluginCount() != 1 { t.Fatalf("expected 1 plugin, got %d", reg.PluginCount()) } got, ok := reg.Get("mock-firewall") if !ok { t.Fatal("plugin not found") } ne, ok := got.(NetworkEnforcer) if !ok { t.Fatal("plugin does not implement NetworkEnforcer") } if ne.Vendor() != "mock-firewall" { t.Fatalf("expected vendor mock-firewall, got %s", ne.Vendor()) } } func TestRegistry_DisabledPlugin(t *testing.T) { reg := NewPluginRegistry() reg.RegisterFactory(PluginTypeFirewall, "disabled-fw", func() interface{} { return newMockFirewall(true) }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "disabled-fw", Enabled: false}, }, } _ = reg.LoadPlugins(cfg) if reg.PluginCount() != 0 { t.Fatalf("disabled plugin should not be loaded, got %d", reg.PluginCount()) } } func TestRegistry_MissingFactory(t *testing.T) { reg := NewPluginRegistry() cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "non-existent", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) if reg.PluginCount() != 0 { t.Fatalf("expected 0 plugins, got %d", reg.PluginCount()) } } func TestRegistry_GetByType(t *testing.T) { reg := NewPluginRegistry() reg.RegisterFactory(PluginTypeFirewall, "fw1", func() interface{} { return newMockFirewall(true) }) reg.RegisterFactory(PluginTypeEDR, "edr1", func() interface{} { return newMockEDR(true) }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "fw1", Enabled: true}, {Type: PluginTypeEDR, Vendor: "edr1", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) firewalls := reg.GetByType(PluginTypeFirewall) if len(firewalls) != 1 { t.Fatalf("expected 1 firewall, got %d", len(firewalls)) } edrs := reg.GetByType(PluginTypeEDR) if len(edrs) != 1 { t.Fatalf("expected 1 edr, got %d", len(edrs)) } } func TestRegistry_TypedGetters(t *testing.T) { reg := NewPluginRegistry() reg.RegisterFactory(PluginTypeFirewall, "fw1", func() interface{} { return newMockFirewall(true) }) reg.RegisterFactory(PluginTypeEDR, "edr1", func() interface{} { return newMockEDR(true) }) reg.RegisterFactory(PluginTypeProxy, "gw1", func() interface{} { return newMockGateway(true) }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "fw1", Enabled: true}, {Type: PluginTypeEDR, Vendor: "edr1", Enabled: true}, {Type: PluginTypeProxy, Vendor: "gw1", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) if len(reg.GetNetworkEnforcers()) != 1 { t.Fatal("expected 1 NetworkEnforcer") } if len(reg.GetEndpointControllers()) != 1 { t.Fatal("expected 1 EndpointController") } if len(reg.GetWebGateways()) != 1 { t.Fatal("expected 1 WebGateway") } } func TestRegistry_Vendors(t *testing.T) { reg := NewPluginRegistry() reg.RegisterFactory(PluginTypeFirewall, "a", func() interface{} { return newMockFirewall(true) }) reg.RegisterFactory(PluginTypeEDR, "b", func() interface{} { return newMockEDR(true) }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "a", Enabled: true}, {Type: PluginTypeEDR, Vendor: "b", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) vendors := reg.Vendors() if len(vendors) != 2 { t.Fatalf("expected 2 vendors, got %d", len(vendors)) } } // --- Health Tests --- func TestHealth_PluginHealthy(t *testing.T) { reg := NewPluginRegistry() fw := newMockFirewall(true) reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "fw", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) hc := NewHealthChecker(reg, time.Second, nil) hc.CheckNow(context.Background()) h, ok := reg.GetHealth("fw") if !ok || h.Status != PluginStatusHealthy { t.Fatalf("expected healthy, got %v", h) } } func TestHealth_PluginOffline(t *testing.T) { reg := NewPluginRegistry() fw := newMockFirewall(false) // unhealthy reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "fw", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) hc := NewHealthChecker(reg, time.Second, nil) // Check 3 times to trigger offline. for i := 0; i < MaxConsecutivePluginFailures; i++ { hc.CheckNow(context.Background()) } h, ok := reg.GetHealth("fw") if !ok { t.Fatal("health not found") } if h.Status != PluginStatusOffline { t.Fatalf("expected offline, got %s (consecutive=%d)", h.Status, h.Consecutive) } } func TestHealth_PluginRecovery(t *testing.T) { reg := NewPluginRegistry() fw := newMockFirewall(false) // start unhealthy reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "fw", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) var alerts []string hc := NewHealthChecker(reg, time.Second, func(vendor string, status PluginStatus, msg string) { alerts = append(alerts, fmt.Sprintf("%s:%s", vendor, status)) }) // Make it go offline. for i := 0; i < MaxConsecutivePluginFailures; i++ { hc.CheckNow(context.Background()) } // Now recover. fw.healthy = true hc.CheckNow(context.Background()) h, _ := reg.GetHealth("fw") if h.Status != PluginStatusHealthy { t.Fatalf("expected healthy after recovery, got %s", h.Status) } if len(alerts) < 2 { t.Fatalf("expected at least 2 alerts (offline + recovery), got %d", len(alerts)) } } // --- Fallback Tests --- func TestFallback_BlockDomain_Healthy(t *testing.T) { reg := NewPluginRegistry() fw := newMockFirewall(true) reg.RegisterFactory(PluginTypeFirewall, "mock-firewall", func() interface{} { return fw }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "mock-firewall", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) fm := NewFallbackManager(reg, "detect_only") vendor, err := fm.BlockDomain(context.Background(), "api.openai.com", "test") if err != nil { t.Fatalf("BlockDomain: %v", err) } if vendor != "mock-firewall" { t.Fatalf("expected vendor 'mock-firewall', got '%s'", vendor) } fw.mu.Lock() if len(fw.blockDomains) != 1 || fw.blockDomains[0] != "api.openai.com" { t.Fatalf("expected blocked domain, got %v", fw.blockDomains) } fw.mu.Unlock() } func TestFallback_AllOffline_DetectOnly(t *testing.T) { reg := NewPluginRegistry() fw := newMockFirewall(false) // offline reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "fw", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) // Mark as offline. reg.SetHealth("fw", &PluginHealth{Vendor: "fw", Status: PluginStatusOffline}) var detected []ShadowAIEvent fm := NewFallbackManager(reg, "detect_only") fm.SetEventLogger(func(e ShadowAIEvent) { detected = append(detected, e) }) vendor, err := fm.BlockDomain(context.Background(), "api.openai.com", "test") if err != nil { t.Fatalf("unexpected error: %v", err) } if vendor != "" { t.Fatalf("expected empty vendor for detect_only, got '%s'", vendor) } if len(detected) != 1 { t.Fatalf("expected 1 detect_only event, got %d", len(detected)) } if detected[0].Action != "detect_only" { t.Fatalf("expected action 'detect_only', got '%s'", detected[0].Action) } } func TestFallback_IsolateHost(t *testing.T) { reg := NewPluginRegistry() edr := newMockEDR(true) reg.RegisterFactory(PluginTypeEDR, "mock-edr", func() interface{} { return edr }) cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeEDR, Vendor: "mock-edr", Enabled: true}, }, } _ = reg.LoadPlugins(cfg) fm := NewFallbackManager(reg, "detect_only") vendor, err := fm.IsolateHost(context.Background(), "workstation-1") if err != nil { t.Fatalf("IsolateHost: %v", err) } if vendor != "mock-edr" { t.Fatalf("expected vendor 'mock-edr', got '%s'", vendor) } if len(edr.isolated) != 1 || edr.isolated[0] != "workstation-1" { t.Fatalf("host not isolated: %v", edr.isolated) } } // --- Detection Tests --- func TestDetection_MatchDomain(t *testing.T) { db := NewAISignatureDB() tests := []struct { domain string service string }{ {"chat.openai.com", "ChatGPT"}, {"api.openai.com", "ChatGPT"}, {"claude.ai", "Claude"}, {"api.anthropic.com", "Claude"}, {"gemini.google.com", "Gemini"}, {"api.deepseek.com", "DeepSeek"}, {"api.mistral.ai", "Mistral"}, {"api.groq.com", "Groq"}, {"example.com", ""}, {"google.com", ""}, } for _, tt := range tests { t.Run(tt.domain, func(t *testing.T) { result := db.MatchDomain(tt.domain) if result != tt.service { t.Errorf("MatchDomain(%q) = %q, want %q", tt.domain, result, tt.service) } }) } } func TestDetection_ServiceCount(t *testing.T) { db := NewAISignatureDB() if db.ServiceCount() < 30 { t.Fatalf("expected at least 30 AI services, got %d", db.ServiceCount()) } if db.DomainPatternCount() < 50 { t.Fatalf("expected at least 50 domain patterns, got %d", db.DomainPatternCount()) } } func TestDetection_AddCustomService(t *testing.T) { db := NewAISignatureDB() initial := db.ServiceCount() db.AddService(AIServiceInfo{ Name: "InternalLLM", Vendor: "Internal", Domains: []string{"llm.internal.corp"}, }) if db.ServiceCount() != initial+1 { t.Fatal("service not added") } result := db.MatchDomain("llm.internal.corp") if result != "InternalLLM" { t.Fatalf("custom service not matched: got %q", result) } } func TestDetection_ScanAPIKey_OpenAI(t *testing.T) { db := NewAISignatureDB() // Generate a mock key that matches the pattern. content := "My key is sk-proj-abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMN" result := db.ScanForAPIKeys(content) if result != "OpenAI Project Key" { t.Fatalf("expected OpenAI Project Key detection, got %q", result) } } func TestDetection_ScanAPIKey_NoMatch(t *testing.T) { db := NewAISignatureDB() result := db.ScanForAPIKeys("this is normal text without any API keys") if result != "" { t.Fatalf("expected no match, got %q", result) } } func TestDetection_MatchHTTPHeaders(t *testing.T) { db := NewAISignatureDB() headers := map[string]string{ "Authorization": "Bearer sk-abc123", } result := db.MatchHTTPHeaders(headers) if result != "authorization: bearer sk-" { t.Fatalf("expected OpenAI header match, got %q", result) } } func TestDetection_MatchHTTPHeaders_NoMatch(t *testing.T) { db := NewAISignatureDB() headers := map[string]string{ "Authorization": "Bearer jwt-token-xyz", } result := db.MatchHTTPHeaders(headers) if result != "" { t.Fatalf("expected no match, got %q", result) } } func TestDetection_NetworkDetector(t *testing.T) { nd := NewNetworkDetector() event := NetworkEvent{ User: "user1", Hostname: "ws-001", Destination: "api.openai.com", DataSize: 1024, Timestamp: time.Now(), } detected := nd.Analyze(event) if detected == nil { t.Fatal("expected detection for api.openai.com") } if detected.AIService != "ChatGPT" { t.Fatalf("expected ChatGPT, got %s", detected.AIService) } if detected.DetectionMethod != DetectNetwork { t.Fatalf("expected network detection, got %s", detected.DetectionMethod) } } func TestDetection_NetworkDetector_NoMatch(t *testing.T) { nd := NewNetworkDetector() event := NetworkEvent{ User: "user1", Destination: "example.com", Timestamp: time.Now(), } if nd.Analyze(event) != nil { t.Fatal("should not detect non-AI domain") } } func TestDetection_HTTPSignature(t *testing.T) { nd := NewNetworkDetector() event := NetworkEvent{ User: "user1", Destination: "some-proxy.corp.internal", HTTPHeaders: map[string]string{ "Authorization": "Bearer sk-abc123def456", }, Timestamp: time.Now(), } detected := nd.Analyze(event) if detected == nil { t.Fatal("expected detection via HTTP sig") } if detected.DetectionMethod != DetectHTTP { t.Fatalf("expected HTTP detection, got %s", detected.DetectionMethod) } } // --- Behavioral Tests --- func TestBehavioral_FirstAccess(t *testing.T) { bd := NewBehavioralDetector(10) bd.RecordAccess("user1", "api.openai.com", 1024) alerts := bd.DetectAnomalies() found := false for _, a := range alerts { if a.UserID == "user1" && a.AnomalyType == "first_ai_access" { found = true } } if !found { t.Fatal("expected first_ai_access alert for user without baseline") } } func TestBehavioral_AccessSpike(t *testing.T) { bd := NewBehavioralDetector(10) bd.SetBaseline("user1", &UserBehaviorProfile{ UserID: "user1", AccessFrequency: 5, }) // Record 50 accesses — 10x baseline. for i := 0; i < 50; i++ { bd.RecordAccess("user1", "api.openai.com", 100) } alerts := bd.DetectAnomalies() found := false for _, a := range alerts { if a.UserID == "user1" && a.AnomalyType == "access_spike" { found = true } } if !found { t.Fatal("expected access_spike alert") } } func TestBehavioral_NewDestination(t *testing.T) { bd := NewBehavioralDetector(10) bd.SetBaseline("user1", &UserBehaviorProfile{ UserID: "user1", AccessFrequency: 5, KnownDestinations: []string{"api.openai.com"}, }) bd.RecordAccess("user1", "api.anthropic.com", 100) alerts := bd.DetectAnomalies() found := false for _, a := range alerts { if a.UserID == "user1" && a.AnomalyType == "new_ai_destination" { found = true if a.Destination != "api.anthropic.com" { t.Fatalf("expected destination api.anthropic.com, got %s", a.Destination) } } } if !found { t.Fatal("expected new_ai_destination alert") } } func TestBehavioral_ResetCurrent(t *testing.T) { bd := NewBehavioralDetector(10) bd.RecordAccess("user1", "api.openai.com", 1024) bd.ResetCurrent() alerts := bd.DetectAnomalies() if len(alerts) != 0 { t.Fatalf("expected 0 alerts after reset, got %d", len(alerts)) } } // --- Controller Tests --- func TestController_ProcessNetworkEvent(t *testing.T) { ctrl := NewShadowAIController() var socEvents []string ctrl.SetSOCEventEmitter(func(source, severity, category, description string, meta map[string]string) { socEvents = append(socEvents, category+":"+description) }) event := NetworkEvent{ User: "user1", Hostname: "ws-001", Destination: "api.openai.com", DataSize: 2048, Timestamp: time.Now(), } detected := ctrl.ProcessNetworkEvent(context.Background(), event) if detected == nil { t.Fatal("expected detection") } if detected.AIService != "ChatGPT" { t.Fatalf("expected ChatGPT, got %s", detected.AIService) } if len(socEvents) != 1 { t.Fatalf("expected 1 SOC event, got %d", len(socEvents)) } } func TestController_GetStats(t *testing.T) { ctrl := NewShadowAIController() // Process a few events. for i := 0; i < 5; i++ { ctrl.ProcessNetworkEvent(context.Background(), NetworkEvent{ User: fmt.Sprintf("user%d", i%3), Destination: "api.openai.com", Timestamp: time.Now(), }) } stats := ctrl.GetStats("24h") if stats.Total != 5 { t.Fatalf("expected 5 total, got %d", stats.Total) } if stats.ByService["ChatGPT"] != 5 { t.Fatalf("expected 5 ChatGPT, got %d", stats.ByService["ChatGPT"]) } if len(stats.TopViolators) == 0 { t.Fatal("expected at least 1 violator") } } func TestController_GetEvents(t *testing.T) { ctrl := NewShadowAIController() for i := 0; i < 10; i++ { ctrl.ProcessNetworkEvent(context.Background(), NetworkEvent{ User: "user1", Destination: "api.openai.com", Timestamp: time.Now(), }) } events := ctrl.GetEvents(5) if len(events) != 5 { t.Fatalf("expected 5 events, got %d", len(events)) } } func TestController_ScanContent(t *testing.T) { ctrl := NewShadowAIController() result := ctrl.ScanContent("nothing here") if result != "" { t.Fatalf("expected no detection, got %q", result) } } func TestController_ComplianceReport(t *testing.T) { ctrl := NewShadowAIController() report := ctrl.GenerateComplianceReport("30d") if report.Period != "30d" { t.Fatalf("expected period 30d, got %s", report.Period) } if !report.AuditComplete { t.Fatal("expected audit complete") } if len(report.Regulations) != 3 { t.Fatalf("expected 3 regulations, got %d", len(report.Regulations)) } } func TestController_IntegrationHealth(t *testing.T) { ctrl := NewShadowAIController() health := ctrl.IntegrationHealth() if health == nil { t.Fatal("expected non-nil health") } } func TestServicesByCategory(t *testing.T) { categories := ServicesByCategory() if len(categories) == 0 { t.Fatal("expected categories") } if _, ok := categories["llm"]; !ok { t.Fatal("expected 'llm' category") } if _, ok := categories["code_assist"]; !ok { t.Fatal("expected 'code_assist' category") } } func TestController_EventBounded(t *testing.T) { ctrl := NewShadowAIController() // Override maxEvents for testing. ctrl.mu.Lock() ctrl.maxEvents = 10 ctrl.mu.Unlock() for i := 0; i < 20; i++ { ctrl.ProcessNetworkEvent(context.Background(), NetworkEvent{ User: "user1", Destination: "api.openai.com", Timestamp: time.Now(), }) } events := ctrl.GetEvents(100) if len(events) > 10 { t.Fatalf("expected max 10 events, got %d", len(events)) } } // ===================================================== // Phase 3: Document Review Bridge Tests // ===================================================== func TestDocBridge_CleanContent(t *testing.T) { db := NewDocBridge() result := db.ScanDocument("doc-1", "This is clean text without any PII or secrets.", "user1") if result.Status != DocReviewClean { t.Fatalf("expected clean, got %s", result.Status) } if result.DataClass != DataPublic { t.Fatalf("expected PUBLIC, got %s", result.DataClass) } if len(result.PIIFound) != 0 { t.Fatalf("expected 0 PII, got %d", len(result.PIIFound)) } } func TestDocBridge_DetectEmail(t *testing.T) { db := NewDocBridge() result := db.ScanDocument("doc-2", "Please contact john.doe@example.com for details.", "user1") if len(result.PIIFound) == 0 { t.Fatal("expected PII detection for email") } found := false for _, pii := range result.PIIFound { if pii.Type == "email" { found = true if pii.Masked == "" { t.Fatal("expected masked email") } } } if !found { t.Fatal("expected email PII type") } } func TestDocBridge_DetectSSN(t *testing.T) { db := NewDocBridge() result := db.ScanDocument("doc-3", "SSN: 123-45-6789", "user1") found := false for _, pii := range result.PIIFound { if pii.Type == "ssn" { found = true if pii.Masked != "***-**-****" { t.Fatalf("expected masked SSN, got %q", pii.Masked) } } } if !found { t.Fatal("expected SSN detection") } if result.DataClass != DataCritical { t.Fatalf("SSN should classify as CRITICAL, got %s", result.DataClass) } } func TestDocBridge_DetectCreditCard(t *testing.T) { db := NewDocBridge() result := db.ScanDocument("doc-4", "Card: 4111 1111 1111 1111", "user1") found := false for _, pii := range result.PIIFound { if pii.Type == "credit_card" { found = true } } if !found { t.Fatal("expected credit_card detection") } if result.DataClass != DataCritical { t.Fatalf("credit card should classify as CRITICAL, got %s", result.DataClass) } } func TestDocBridge_DetectAWSKey(t *testing.T) { db := NewDocBridge() result := db.ScanDocument("doc-5", "AWS key: AKIAIOSFODNN7EXAMPLE", "user1") if len(result.SecretsFound) == 0 { t.Fatal("expected AWS key detection") } found := false for _, s := range result.SecretsFound { if s.Provider == "AWS" { found = true } } if !found { t.Fatal("expected AWS provider") } if result.Status != DocReviewBlocked { t.Fatalf("secrets should block, got %s", result.Status) } } func TestDocBridge_DetectGitHubToken(t *testing.T) { db := NewDocBridge() result := db.ScanDocument("doc-6", "token: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij", "user1") found := false for _, s := range result.SecretsFound { if s.Provider == "GitHub" { found = true } } if !found { t.Fatal("expected GitHub token detection") } } func TestDocBridge_RedactContent(t *testing.T) { db := NewDocBridge() content := "Email: john@example.com, SSN: 123-45-6789, Key: AKIAIOSFODNN7EXAMPLE" redacted := db.RedactContent(content) if redacted == content { t.Fatal("expected content to be modified") } // Email should be partially masked. if strings.Contains(redacted, "john@example.com") { t.Fatal("email should be redacted") } // SSN should be fully masked. if strings.Contains(redacted, "123-45-6789") { t.Fatal("SSN should be redacted") } // AWS key should be replaced. if strings.Contains(redacted, "AKIAIOSFODNN7EXAMPLE") { t.Fatal("AWS key should be redacted") } } func TestDocBridge_GetReview(t *testing.T) { db := NewDocBridge() db.ScanDocument("doc-7", "clean text", "user1") r, ok := db.GetReview("doc-7") if !ok || r == nil { t.Fatal("review not found") } if r.DocumentID != "doc-7" { t.Fatalf("expected doc-7, got %s", r.DocumentID) } } func TestDocBridge_Stats(t *testing.T) { db := NewDocBridge() db.ScanDocument("d1", "clean text", "u1") db.ScanDocument("d2", "email: a@b.com", "u1") db.ScanDocument("d3", "key: AKIAIOSFODNN7EXAMPLE", "u1") stats := db.Stats() if stats["total"] != 3 { t.Fatalf("expected 3 total, got %d", stats["total"]) } if stats["clean"] != 1 { t.Fatalf("expected 1 clean, got %d", stats["clean"]) } } // ===================================================== // Phase 3: Approval Engine Tests // ===================================================== func TestApproval_AutoApprove_Public(t *testing.T) { ae := NewApprovalEngine() req := ae.SubmitRequest("user1", "doc-1", DataPublic) if req.Status != string(ApprovalAutoApproved) { t.Fatalf("expected auto_approved for PUBLIC, got %s", req.Status) } if req.ApprovedBy != "system" { t.Fatalf("expected approved by system, got %s", req.ApprovedBy) } } func TestApproval_PendingInternal(t *testing.T) { ae := NewApprovalEngine() req := ae.SubmitRequest("user1", "doc-2", DataInternal) if req.Status != string(ApprovalPending) { t.Fatalf("expected pending for INTERNAL, got %s", req.Status) } if req.ExpiresAt.IsZero() { t.Fatal("expected non-zero expiry for INTERNAL") } } func TestApproval_ApproveFlow(t *testing.T) { ae := NewApprovalEngine() req := ae.SubmitRequest("user1", "doc-3", DataConfidential) if err := ae.Approve(req.ID, "manager1"); err != nil { t.Fatalf("approve: %v", err) } got, ok := ae.GetRequest(req.ID) if !ok { t.Fatal("request not found after approval") } if got.Status != string(ApprovalApproved) { t.Fatalf("expected approved, got %s", got.Status) } if got.ApprovedBy != "manager1" { t.Fatalf("expected manager1, got %s", got.ApprovedBy) } } func TestApproval_DenyFlow(t *testing.T) { ae := NewApprovalEngine() req := ae.SubmitRequest("user1", "doc-4", DataCritical) if err := ae.Deny(req.ID, "ciso", "data too sensitive"); err != nil { t.Fatalf("deny: %v", err) } got, _ := ae.GetRequest(req.ID) if got.Status != string(ApprovalDenied) { t.Fatalf("expected denied, got %s", got.Status) } if got.Reason != "data too sensitive" { t.Fatalf("expected reason, got %q", got.Reason) } } func TestApproval_DoubleApprove(t *testing.T) { ae := NewApprovalEngine() req := ae.SubmitRequest("user1", "doc-5", DataInternal) _ = ae.Approve(req.ID, "mgr") err := ae.Approve(req.ID, "mgr2") if err == nil { t.Fatal("expected error on double approve") } } func TestApproval_ExpireOverdue(t *testing.T) { ae := NewApprovalEngine() req := ae.SubmitRequest("user1", "doc-6", DataInternal) // Manually set ExpiresAt to the past. ae.mu.Lock() ae.requests[req.ID].ExpiresAt = time.Now().Add(-1 * time.Hour) ae.mu.Unlock() expired := ae.ExpireOverdue() if expired != 1 { t.Fatalf("expected 1 expired, got %d", expired) } got, _ := ae.GetRequest(req.ID) if got.Status != string(ApprovalExpired) { t.Fatalf("expected expired, got %s", got.Status) } } func TestApproval_Stats(t *testing.T) { ae := NewApprovalEngine() ae.SubmitRequest("u1", "d1", DataPublic) // auto ae.SubmitRequest("u2", "d2", DataInternal) // pending req := ae.SubmitRequest("u3", "d3", DataConfidential) // pending _ = ae.Deny(req.ID, "ciso", "no") stats := ae.Stats() if stats["auto_approved"] != 1 { t.Fatalf("expected 1 auto_approved, got %d", stats["auto_approved"]) } if stats["pending"] != 1 { t.Fatalf("expected 1 pending, got %d", stats["pending"]) } if stats["denied"] != 1 { t.Fatalf("expected 1 denied, got %d", stats["denied"]) } } func TestApproval_Tiers(t *testing.T) { ae := NewApprovalEngine() tiers := ae.Tiers() if len(tiers) != 4 { t.Fatalf("expected 4 tiers, got %d", len(tiers)) } } // ===================================================== // Phase 3: Vendor Plugin Stubs // ===================================================== func TestPlugins_RegisterDefault(t *testing.T) { reg := NewPluginRegistry() RegisterDefaultPlugins(reg) // Provide required config for each vendor stub. cfg := &IntegrationConfig{ Plugins: []PluginConfig{ {Type: PluginTypeFirewall, Vendor: "checkpoint", Enabled: true, Config: map[string]interface{}{"api_url": "https://cp.local"}}, {Type: PluginTypeEDR, Vendor: "crowdstrike", Enabled: true, Config: map[string]interface{}{"client_id": "test-id"}}, {Type: PluginTypeProxy, Vendor: "zscaler", Enabled: true, Config: map[string]interface{}{"cloud_name": "zscaler.net"}}, }, } _ = reg.LoadPlugins(cfg) if reg.PluginCount() != 3 { t.Fatalf("expected 3 plugins, got %d", reg.PluginCount()) } } func TestPlugins_CheckPoint_Vendor(t *testing.T) { cp := NewCheckPointEnforcer() if cp.Vendor() != "checkpoint" { t.Fatalf("expected 'checkpoint', got %s", cp.Vendor()) } } func TestPlugins_CrowdStrike_Vendor(t *testing.T) { cs := NewCrowdStrikeController() if cs.Vendor() != "crowdstrike" { t.Fatalf("expected 'crowdstrike', got %s", cs.Vendor()) } } func TestPlugins_Zscaler_Vendor(t *testing.T) { z := NewZscalerGateway() if z.Vendor() != "zscaler" { t.Fatalf("expected 'zscaler', got %s", z.Vendor()) } } // ===================================================== // Phase 3: Correlation Rules // ===================================================== func TestCorrelation_RuleCount(t *testing.T) { rules := ShadowAICorrelationRules() if len(rules) != 9 { t.Fatalf("expected 9 correlation rules, got %d", len(rules)) } } func TestCorrelation_RuleIDs(t *testing.T) { rules := ShadowAICorrelationRules() ids := make(map[string]bool) for _, r := range rules { if ids[r.ID] { t.Fatalf("duplicate rule ID: %s", r.ID) } ids[r.ID] = true } } func TestCorrelation_CriticalRules(t *testing.T) { rules := ShadowAICorrelationRules() critical := 0 for _, r := range rules { if r.Severity == "CRITICAL" { critical++ } } if critical < 3 { t.Fatalf("expected at least 3 CRITICAL rules, got %d", critical) } } // ===================================================== // Phase 3: Controller Integration (DocBridge + Approval) // ===================================================== func TestController_ReviewDocument_Clean(t *testing.T) { ctrl := NewShadowAIController() result, approval := ctrl.ReviewDocument("doc-1", "clean content", "user1") if result.Status != DocReviewClean { t.Fatalf("expected clean, got %s", result.Status) } if approval == nil { t.Fatal("expected auto-approval for clean doc") } if approval.Status != string(ApprovalAutoApproved) { t.Fatalf("expected auto_approved, got %s", approval.Status) } } func TestController_ReviewDocument_WithPII(t *testing.T) { ctrl := NewShadowAIController() result, approval := ctrl.ReviewDocument("doc-2", "Contact: alice@corp.com", "user1") if result.Status != DocReviewRedacted { t.Fatalf("expected redacted, got %s", result.Status) } if approval == nil { t.Fatal("expected approval request for PII") } if approval.Status != string(ApprovalPending) { t.Fatalf("expected pending, got %s", approval.Status) } } func TestController_ReviewDocument_WithSecrets(t *testing.T) { ctrl := NewShadowAIController() result, approval := ctrl.ReviewDocument("doc-3", "key: AKIAIOSFODNN7EXAMPLE", "user1") if result.Status != DocReviewBlocked { t.Fatalf("expected blocked, got %s", result.Status) } // Should NOT create approval for blocked docs. if approval != nil { t.Fatal("blocked docs should not create approval") } }