mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-24 20:06:21 +02:00
1228 lines
31 KiB
Go
1228 lines
31 KiB
Go
// Copyright 2026 Syntrex Lab. All rights reserved.
|
|
// Use of this source code is governed by an Apache-2.0 license
|
|
// that can be found in the LICENSE file.
|
|
|
|
package shadow_ai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// --- Mock Plugins ---
|
|
|
|
// mockFirewall is a test double for a firewall plugin. It never talks to a
// real device: it remembers which IPs and domains it was told to block and
// reports a fixed health state chosen at construction time.
type mockFirewall struct {
	blockIPs     []string   // every ip passed to BlockIP, in call order
	blockDomains []string   // every domain passed to BlockDomain, in call order
	healthy      bool       // when false, HealthCheck returns an error
	mu           sync.Mutex // guards blockIPs and blockDomains
}

// newMockFirewall returns an empty mockFirewall whose HealthCheck succeeds
// exactly when healthy is true.
func newMockFirewall(healthy bool) *mockFirewall {
	fw := new(mockFirewall)
	fw.healthy = healthy
	return fw
}

// record appends v to *dst while holding f.mu.
func (f *mockFirewall) record(dst *[]string, v string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	*dst = append(*dst, v)
}

// BlockIP records ip and always succeeds; the duration and reason are ignored.
func (f *mockFirewall) BlockIP(_ context.Context, ip string, _ time.Duration, _ string) error {
	f.record(&f.blockIPs, ip)
	return nil
}

// BlockDomain records domain and always succeeds; the reason is ignored.
func (f *mockFirewall) BlockDomain(_ context.Context, domain string, _ string) error {
	f.record(&f.blockDomains, domain)
	return nil
}

// UnblockIP is a no-op that always succeeds.
func (f *mockFirewall) UnblockIP(_ context.Context, _ string) error { return nil }

// UnblockDomain is a no-op that always succeeds.
func (f *mockFirewall) UnblockDomain(_ context.Context, _ string) error { return nil }

// HealthCheck returns nil for a healthy mock and a "firewall offline" error otherwise.
func (f *mockFirewall) HealthCheck(_ context.Context) error {
	if f.healthy {
		return nil
	}
	return fmt.Errorf("firewall offline")
}

// Vendor reports the fixed vendor name "mock-firewall".
func (f *mockFirewall) Vendor() string { return "mock-firewall" }
|
|
|
|
// mockEDR is a test double for an EDR plugin. It records the hostnames it is
// asked to isolate and reports a fixed health state chosen at construction.
// Like mockFirewall, it guards its recorded calls with a mutex so it stays
// race-free if the code under test dispatches actions from several goroutines.
type mockEDR struct {
	isolated []string   // hostnames passed to IsolateHost, in call order
	healthy  bool       // when false, HealthCheck returns an error
	mu       sync.Mutex // guards isolated
}

// newMockEDR returns an empty mockEDR whose HealthCheck succeeds exactly when
// healthy is true.
func newMockEDR(healthy bool) *mockEDR {
	return &mockEDR{healthy: healthy}
}

// IsolateHost records hostname and always succeeds.
func (m *mockEDR) IsolateHost(_ context.Context, hostname string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.isolated = append(m.isolated, hostname)
	return nil
}

// ReleaseHost is a no-op that always succeeds.
func (m *mockEDR) ReleaseHost(_ context.Context, _ string) error { return nil }

// KillProcess is a no-op that always succeeds.
func (m *mockEDR) KillProcess(_ context.Context, _ string, _ int) error { return nil }

// QuarantineFile is a no-op that always succeeds.
func (m *mockEDR) QuarantineFile(_ context.Context, _ string, _ string) error { return nil }

// HealthCheck returns an "EDR offline" error when the mock was built unhealthy.
func (m *mockEDR) HealthCheck(_ context.Context) error {
	if !m.healthy {
		return fmt.Errorf("EDR offline")
	}
	return nil
}

// Vendor reports the fixed vendor name "mock-edr".
func (m *mockEDR) Vendor() string { return "mock-edr" }
|
|
|
|
// mockGateway is a test double for a web-gateway plugin. It records the URLs
// it is asked to block and reports a fixed health state chosen at
// construction. Like mockFirewall, it guards its recorded calls with a mutex
// so it stays race-free under concurrent dispatch.
type mockGateway struct {
	blockedURLs []string   // urls passed to BlockURL, in call order
	healthy     bool       // when false, HealthCheck returns an error
	mu          sync.Mutex // guards blockedURLs
}

// newMockGateway returns an empty mockGateway whose HealthCheck succeeds
// exactly when healthy is true.
func newMockGateway(healthy bool) *mockGateway {
	return &mockGateway{healthy: healthy}
}

// BlockURL records url and always succeeds; the reason is ignored.
func (m *mockGateway) BlockURL(_ context.Context, url string, _ string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.blockedURLs = append(m.blockedURLs, url)
	return nil
}

// UnblockURL is a no-op that always succeeds.
func (m *mockGateway) UnblockURL(_ context.Context, _ string) error { return nil }

// BlockCategory is a no-op that always succeeds.
func (m *mockGateway) BlockCategory(_ context.Context, _ string) error { return nil }

// HealthCheck returns a "gateway offline" error when the mock was built unhealthy.
func (m *mockGateway) HealthCheck(_ context.Context) error {
	if !m.healthy {
		return fmt.Errorf("gateway offline")
	}
	return nil
}

// Vendor reports the fixed vendor name "mock-gateway".
func (m *mockGateway) Vendor() string { return "mock-gateway" }
|
|
|
|
// --- Registry Tests ---
|
|
|
|
func TestRegistry_RegisterAndGet(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
|
|
fw := newMockFirewall(true)
|
|
reg.RegisterFactory(PluginTypeFirewall, "mock-firewall", func() interface{} {
|
|
return fw
|
|
})
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "mock-firewall", Enabled: true},
|
|
},
|
|
}
|
|
|
|
if err := reg.LoadPlugins(cfg); err != nil {
|
|
t.Fatalf("LoadPlugins: %v", err)
|
|
}
|
|
|
|
if reg.PluginCount() != 1 {
|
|
t.Fatalf("expected 1 plugin, got %d", reg.PluginCount())
|
|
}
|
|
|
|
got, ok := reg.Get("mock-firewall")
|
|
if !ok {
|
|
t.Fatal("plugin not found")
|
|
}
|
|
|
|
ne, ok := got.(NetworkEnforcer)
|
|
if !ok {
|
|
t.Fatal("plugin does not implement NetworkEnforcer")
|
|
}
|
|
|
|
if ne.Vendor() != "mock-firewall" {
|
|
t.Fatalf("expected vendor mock-firewall, got %s", ne.Vendor())
|
|
}
|
|
}
|
|
|
|
func TestRegistry_DisabledPlugin(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
reg.RegisterFactory(PluginTypeFirewall, "disabled-fw", func() interface{} {
|
|
return newMockFirewall(true)
|
|
})
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "disabled-fw", Enabled: false},
|
|
},
|
|
}
|
|
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
if reg.PluginCount() != 0 {
|
|
t.Fatalf("disabled plugin should not be loaded, got %d", reg.PluginCount())
|
|
}
|
|
}
|
|
|
|
func TestRegistry_MissingFactory(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "non-existent", Enabled: true},
|
|
},
|
|
}
|
|
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
if reg.PluginCount() != 0 {
|
|
t.Fatalf("expected 0 plugins, got %d", reg.PluginCount())
|
|
}
|
|
}
|
|
|
|
func TestRegistry_GetByType(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
|
|
reg.RegisterFactory(PluginTypeFirewall, "fw1", func() interface{} {
|
|
return newMockFirewall(true)
|
|
})
|
|
reg.RegisterFactory(PluginTypeEDR, "edr1", func() interface{} {
|
|
return newMockEDR(true)
|
|
})
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "fw1", Enabled: true},
|
|
{Type: PluginTypeEDR, Vendor: "edr1", Enabled: true},
|
|
},
|
|
}
|
|
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
firewalls := reg.GetByType(PluginTypeFirewall)
|
|
if len(firewalls) != 1 {
|
|
t.Fatalf("expected 1 firewall, got %d", len(firewalls))
|
|
}
|
|
|
|
edrs := reg.GetByType(PluginTypeEDR)
|
|
if len(edrs) != 1 {
|
|
t.Fatalf("expected 1 edr, got %d", len(edrs))
|
|
}
|
|
}
|
|
|
|
func TestRegistry_TypedGetters(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
|
|
reg.RegisterFactory(PluginTypeFirewall, "fw1", func() interface{} {
|
|
return newMockFirewall(true)
|
|
})
|
|
reg.RegisterFactory(PluginTypeEDR, "edr1", func() interface{} {
|
|
return newMockEDR(true)
|
|
})
|
|
reg.RegisterFactory(PluginTypeProxy, "gw1", func() interface{} {
|
|
return newMockGateway(true)
|
|
})
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "fw1", Enabled: true},
|
|
{Type: PluginTypeEDR, Vendor: "edr1", Enabled: true},
|
|
{Type: PluginTypeProxy, Vendor: "gw1", Enabled: true},
|
|
},
|
|
}
|
|
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
if len(reg.GetNetworkEnforcers()) != 1 {
|
|
t.Fatal("expected 1 NetworkEnforcer")
|
|
}
|
|
if len(reg.GetEndpointControllers()) != 1 {
|
|
t.Fatal("expected 1 EndpointController")
|
|
}
|
|
if len(reg.GetWebGateways()) != 1 {
|
|
t.Fatal("expected 1 WebGateway")
|
|
}
|
|
}
|
|
|
|
func TestRegistry_Vendors(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
reg.RegisterFactory(PluginTypeFirewall, "a", func() interface{} {
|
|
return newMockFirewall(true)
|
|
})
|
|
reg.RegisterFactory(PluginTypeEDR, "b", func() interface{} {
|
|
return newMockEDR(true)
|
|
})
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "a", Enabled: true},
|
|
{Type: PluginTypeEDR, Vendor: "b", Enabled: true},
|
|
},
|
|
}
|
|
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
vendors := reg.Vendors()
|
|
if len(vendors) != 2 {
|
|
t.Fatalf("expected 2 vendors, got %d", len(vendors))
|
|
}
|
|
}
|
|
|
|
// --- Health Tests ---
|
|
|
|
func TestHealth_PluginHealthy(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
fw := newMockFirewall(true)
|
|
reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw })
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "fw", Enabled: true},
|
|
},
|
|
}
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
hc := NewHealthChecker(reg, time.Second, nil)
|
|
hc.CheckNow(context.Background())
|
|
|
|
h, ok := reg.GetHealth("fw")
|
|
if !ok || h.Status != PluginStatusHealthy {
|
|
t.Fatalf("expected healthy, got %v", h)
|
|
}
|
|
}
|
|
|
|
func TestHealth_PluginOffline(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
fw := newMockFirewall(false) // unhealthy
|
|
reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw })
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "fw", Enabled: true},
|
|
},
|
|
}
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
hc := NewHealthChecker(reg, time.Second, nil)
|
|
|
|
// Check 3 times to trigger offline.
|
|
for i := 0; i < MaxConsecutivePluginFailures; i++ {
|
|
hc.CheckNow(context.Background())
|
|
}
|
|
|
|
h, ok := reg.GetHealth("fw")
|
|
if !ok {
|
|
t.Fatal("health not found")
|
|
}
|
|
if h.Status != PluginStatusOffline {
|
|
t.Fatalf("expected offline, got %s (consecutive=%d)", h.Status, h.Consecutive)
|
|
}
|
|
}
|
|
|
|
func TestHealth_PluginRecovery(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
fw := newMockFirewall(false) // start unhealthy
|
|
reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw })
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "fw", Enabled: true},
|
|
},
|
|
}
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
var alerts []string
|
|
hc := NewHealthChecker(reg, time.Second, func(vendor string, status PluginStatus, msg string) {
|
|
alerts = append(alerts, fmt.Sprintf("%s:%s", vendor, status))
|
|
})
|
|
|
|
// Make it go offline.
|
|
for i := 0; i < MaxConsecutivePluginFailures; i++ {
|
|
hc.CheckNow(context.Background())
|
|
}
|
|
|
|
// Now recover.
|
|
fw.healthy = true
|
|
hc.CheckNow(context.Background())
|
|
|
|
h, _ := reg.GetHealth("fw")
|
|
if h.Status != PluginStatusHealthy {
|
|
t.Fatalf("expected healthy after recovery, got %s", h.Status)
|
|
}
|
|
|
|
if len(alerts) < 2 {
|
|
t.Fatalf("expected at least 2 alerts (offline + recovery), got %d", len(alerts))
|
|
}
|
|
}
|
|
|
|
// --- Fallback Tests ---
|
|
|
|
func TestFallback_BlockDomain_Healthy(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
fw := newMockFirewall(true)
|
|
reg.RegisterFactory(PluginTypeFirewall, "mock-firewall", func() interface{} { return fw })
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "mock-firewall", Enabled: true},
|
|
},
|
|
}
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
fm := NewFallbackManager(reg, "detect_only")
|
|
vendor, err := fm.BlockDomain(context.Background(), "api.openai.com", "test")
|
|
if err != nil {
|
|
t.Fatalf("BlockDomain: %v", err)
|
|
}
|
|
if vendor != "mock-firewall" {
|
|
t.Fatalf("expected vendor 'mock-firewall', got '%s'", vendor)
|
|
}
|
|
|
|
fw.mu.Lock()
|
|
if len(fw.blockDomains) != 1 || fw.blockDomains[0] != "api.openai.com" {
|
|
t.Fatalf("expected blocked domain, got %v", fw.blockDomains)
|
|
}
|
|
fw.mu.Unlock()
|
|
}
|
|
|
|
func TestFallback_AllOffline_DetectOnly(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
fw := newMockFirewall(false) // offline
|
|
reg.RegisterFactory(PluginTypeFirewall, "fw", func() interface{} { return fw })
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "fw", Enabled: true},
|
|
},
|
|
}
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
// Mark as offline.
|
|
reg.SetHealth("fw", &PluginHealth{Vendor: "fw", Status: PluginStatusOffline})
|
|
|
|
var detected []ShadowAIEvent
|
|
fm := NewFallbackManager(reg, "detect_only")
|
|
fm.SetEventLogger(func(e ShadowAIEvent) {
|
|
detected = append(detected, e)
|
|
})
|
|
|
|
vendor, err := fm.BlockDomain(context.Background(), "api.openai.com", "test")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if vendor != "" {
|
|
t.Fatalf("expected empty vendor for detect_only, got '%s'", vendor)
|
|
}
|
|
|
|
if len(detected) != 1 {
|
|
t.Fatalf("expected 1 detect_only event, got %d", len(detected))
|
|
}
|
|
if detected[0].Action != "detect_only" {
|
|
t.Fatalf("expected action 'detect_only', got '%s'", detected[0].Action)
|
|
}
|
|
}
|
|
|
|
func TestFallback_IsolateHost(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
edr := newMockEDR(true)
|
|
reg.RegisterFactory(PluginTypeEDR, "mock-edr", func() interface{} { return edr })
|
|
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeEDR, Vendor: "mock-edr", Enabled: true},
|
|
},
|
|
}
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
fm := NewFallbackManager(reg, "detect_only")
|
|
vendor, err := fm.IsolateHost(context.Background(), "workstation-1")
|
|
if err != nil {
|
|
t.Fatalf("IsolateHost: %v", err)
|
|
}
|
|
if vendor != "mock-edr" {
|
|
t.Fatalf("expected vendor 'mock-edr', got '%s'", vendor)
|
|
}
|
|
if len(edr.isolated) != 1 || edr.isolated[0] != "workstation-1" {
|
|
t.Fatalf("host not isolated: %v", edr.isolated)
|
|
}
|
|
}
|
|
|
|
// --- Detection Tests ---
|
|
|
|
func TestDetection_MatchDomain(t *testing.T) {
|
|
db := NewAISignatureDB()
|
|
|
|
tests := []struct {
|
|
domain string
|
|
service string
|
|
}{
|
|
{"chat.openai.com", "ChatGPT"},
|
|
{"api.openai.com", "ChatGPT"},
|
|
{"claude.ai", "Claude"},
|
|
{"api.anthropic.com", "Claude"},
|
|
{"gemini.google.com", "Gemini"},
|
|
{"api.deepseek.com", "DeepSeek"},
|
|
{"api.mistral.ai", "Mistral"},
|
|
{"api.groq.com", "Groq"},
|
|
{"example.com", ""},
|
|
{"google.com", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.domain, func(t *testing.T) {
|
|
result := db.MatchDomain(tt.domain)
|
|
if result != tt.service {
|
|
t.Errorf("MatchDomain(%q) = %q, want %q", tt.domain, result, tt.service)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDetection_ServiceCount(t *testing.T) {
|
|
db := NewAISignatureDB()
|
|
if db.ServiceCount() < 30 {
|
|
t.Fatalf("expected at least 30 AI services, got %d", db.ServiceCount())
|
|
}
|
|
if db.DomainPatternCount() < 50 {
|
|
t.Fatalf("expected at least 50 domain patterns, got %d", db.DomainPatternCount())
|
|
}
|
|
}
|
|
|
|
func TestDetection_AddCustomService(t *testing.T) {
|
|
db := NewAISignatureDB()
|
|
initial := db.ServiceCount()
|
|
|
|
db.AddService(AIServiceInfo{
|
|
Name: "InternalLLM",
|
|
Vendor: "Internal",
|
|
Domains: []string{"llm.internal.corp"},
|
|
})
|
|
|
|
if db.ServiceCount() != initial+1 {
|
|
t.Fatal("service not added")
|
|
}
|
|
|
|
result := db.MatchDomain("llm.internal.corp")
|
|
if result != "InternalLLM" {
|
|
t.Fatalf("custom service not matched: got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestDetection_ScanAPIKey_OpenAI(t *testing.T) {
|
|
db := NewAISignatureDB()
|
|
|
|
// Generate a mock key that matches the pattern.
|
|
content := "My key is sk-proj-abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMN"
|
|
result := db.ScanForAPIKeys(content)
|
|
if result != "OpenAI Project Key" {
|
|
t.Fatalf("expected OpenAI Project Key detection, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestDetection_ScanAPIKey_NoMatch(t *testing.T) {
|
|
db := NewAISignatureDB()
|
|
|
|
result := db.ScanForAPIKeys("this is normal text without any API keys")
|
|
if result != "" {
|
|
t.Fatalf("expected no match, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestDetection_MatchHTTPHeaders(t *testing.T) {
|
|
db := NewAISignatureDB()
|
|
|
|
headers := map[string]string{
|
|
"Authorization": "Bearer sk-abc123",
|
|
}
|
|
result := db.MatchHTTPHeaders(headers)
|
|
if result != "authorization: bearer sk-" {
|
|
t.Fatalf("expected OpenAI header match, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestDetection_MatchHTTPHeaders_NoMatch(t *testing.T) {
|
|
db := NewAISignatureDB()
|
|
|
|
headers := map[string]string{
|
|
"Authorization": "Bearer jwt-token-xyz",
|
|
}
|
|
result := db.MatchHTTPHeaders(headers)
|
|
if result != "" {
|
|
t.Fatalf("expected no match, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestDetection_NetworkDetector(t *testing.T) {
|
|
nd := NewNetworkDetector()
|
|
|
|
event := NetworkEvent{
|
|
User: "user1",
|
|
Hostname: "ws-001",
|
|
Destination: "api.openai.com",
|
|
DataSize: 1024,
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
detected := nd.Analyze(event)
|
|
if detected == nil {
|
|
t.Fatal("expected detection for api.openai.com")
|
|
}
|
|
if detected.AIService != "ChatGPT" {
|
|
t.Fatalf("expected ChatGPT, got %s", detected.AIService)
|
|
}
|
|
if detected.DetectionMethod != DetectNetwork {
|
|
t.Fatalf("expected network detection, got %s", detected.DetectionMethod)
|
|
}
|
|
}
|
|
|
|
func TestDetection_NetworkDetector_NoMatch(t *testing.T) {
|
|
nd := NewNetworkDetector()
|
|
|
|
event := NetworkEvent{
|
|
User: "user1",
|
|
Destination: "example.com",
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
if nd.Analyze(event) != nil {
|
|
t.Fatal("should not detect non-AI domain")
|
|
}
|
|
}
|
|
|
|
func TestDetection_HTTPSignature(t *testing.T) {
|
|
nd := NewNetworkDetector()
|
|
|
|
event := NetworkEvent{
|
|
User: "user1",
|
|
Destination: "some-proxy.corp.internal",
|
|
HTTPHeaders: map[string]string{
|
|
"Authorization": "Bearer sk-abc123def456",
|
|
},
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
detected := nd.Analyze(event)
|
|
if detected == nil {
|
|
t.Fatal("expected detection via HTTP sig")
|
|
}
|
|
if detected.DetectionMethod != DetectHTTP {
|
|
t.Fatalf("expected HTTP detection, got %s", detected.DetectionMethod)
|
|
}
|
|
}
|
|
|
|
// --- Behavioral Tests ---
|
|
|
|
func TestBehavioral_FirstAccess(t *testing.T) {
|
|
bd := NewBehavioralDetector(10)
|
|
|
|
bd.RecordAccess("user1", "api.openai.com", 1024)
|
|
|
|
alerts := bd.DetectAnomalies()
|
|
found := false
|
|
for _, a := range alerts {
|
|
if a.UserID == "user1" && a.AnomalyType == "first_ai_access" {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected first_ai_access alert for user without baseline")
|
|
}
|
|
}
|
|
|
|
func TestBehavioral_AccessSpike(t *testing.T) {
|
|
bd := NewBehavioralDetector(10)
|
|
|
|
bd.SetBaseline("user1", &UserBehaviorProfile{
|
|
UserID: "user1",
|
|
AccessFrequency: 5,
|
|
})
|
|
|
|
// Record 50 accesses — 10x baseline.
|
|
for i := 0; i < 50; i++ {
|
|
bd.RecordAccess("user1", "api.openai.com", 100)
|
|
}
|
|
|
|
alerts := bd.DetectAnomalies()
|
|
found := false
|
|
for _, a := range alerts {
|
|
if a.UserID == "user1" && a.AnomalyType == "access_spike" {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected access_spike alert")
|
|
}
|
|
}
|
|
|
|
func TestBehavioral_NewDestination(t *testing.T) {
|
|
bd := NewBehavioralDetector(10)
|
|
|
|
bd.SetBaseline("user1", &UserBehaviorProfile{
|
|
UserID: "user1",
|
|
AccessFrequency: 5,
|
|
KnownDestinations: []string{"api.openai.com"},
|
|
})
|
|
|
|
bd.RecordAccess("user1", "api.anthropic.com", 100)
|
|
|
|
alerts := bd.DetectAnomalies()
|
|
found := false
|
|
for _, a := range alerts {
|
|
if a.UserID == "user1" && a.AnomalyType == "new_ai_destination" {
|
|
found = true
|
|
if a.Destination != "api.anthropic.com" {
|
|
t.Fatalf("expected destination api.anthropic.com, got %s", a.Destination)
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected new_ai_destination alert")
|
|
}
|
|
}
|
|
|
|
func TestBehavioral_ResetCurrent(t *testing.T) {
|
|
bd := NewBehavioralDetector(10)
|
|
bd.RecordAccess("user1", "api.openai.com", 1024)
|
|
bd.ResetCurrent()
|
|
|
|
alerts := bd.DetectAnomalies()
|
|
if len(alerts) != 0 {
|
|
t.Fatalf("expected 0 alerts after reset, got %d", len(alerts))
|
|
}
|
|
}
|
|
|
|
// --- Controller Tests ---
|
|
|
|
func TestController_ProcessNetworkEvent(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
|
|
var socEvents []string
|
|
ctrl.SetSOCEventEmitter(func(source, severity, category, description string, meta map[string]string) {
|
|
socEvents = append(socEvents, category+":"+description)
|
|
})
|
|
|
|
event := NetworkEvent{
|
|
User: "user1",
|
|
Hostname: "ws-001",
|
|
Destination: "api.openai.com",
|
|
DataSize: 2048,
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
detected := ctrl.ProcessNetworkEvent(context.Background(), event)
|
|
if detected == nil {
|
|
t.Fatal("expected detection")
|
|
}
|
|
if detected.AIService != "ChatGPT" {
|
|
t.Fatalf("expected ChatGPT, got %s", detected.AIService)
|
|
}
|
|
if len(socEvents) != 1 {
|
|
t.Fatalf("expected 1 SOC event, got %d", len(socEvents))
|
|
}
|
|
}
|
|
|
|
func TestController_GetStats(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
|
|
// Process a few events.
|
|
for i := 0; i < 5; i++ {
|
|
ctrl.ProcessNetworkEvent(context.Background(), NetworkEvent{
|
|
User: fmt.Sprintf("user%d", i%3),
|
|
Destination: "api.openai.com",
|
|
Timestamp: time.Now(),
|
|
})
|
|
}
|
|
|
|
stats := ctrl.GetStats("24h")
|
|
if stats.Total != 5 {
|
|
t.Fatalf("expected 5 total, got %d", stats.Total)
|
|
}
|
|
if stats.ByService["ChatGPT"] != 5 {
|
|
t.Fatalf("expected 5 ChatGPT, got %d", stats.ByService["ChatGPT"])
|
|
}
|
|
if len(stats.TopViolators) == 0 {
|
|
t.Fatal("expected at least 1 violator")
|
|
}
|
|
}
|
|
|
|
func TestController_GetEvents(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
|
|
for i := 0; i < 10; i++ {
|
|
ctrl.ProcessNetworkEvent(context.Background(), NetworkEvent{
|
|
User: "user1",
|
|
Destination: "api.openai.com",
|
|
Timestamp: time.Now(),
|
|
})
|
|
}
|
|
|
|
events := ctrl.GetEvents(5)
|
|
if len(events) != 5 {
|
|
t.Fatalf("expected 5 events, got %d", len(events))
|
|
}
|
|
}
|
|
|
|
func TestController_ScanContent(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
|
|
result := ctrl.ScanContent("nothing here")
|
|
if result != "" {
|
|
t.Fatalf("expected no detection, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestController_ComplianceReport(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
|
|
report := ctrl.GenerateComplianceReport("30d")
|
|
if report.Period != "30d" {
|
|
t.Fatalf("expected period 30d, got %s", report.Period)
|
|
}
|
|
if !report.AuditComplete {
|
|
t.Fatal("expected audit complete")
|
|
}
|
|
if len(report.Regulations) != 3 {
|
|
t.Fatalf("expected 3 regulations, got %d", len(report.Regulations))
|
|
}
|
|
}
|
|
|
|
func TestController_IntegrationHealth(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
health := ctrl.IntegrationHealth()
|
|
if health == nil {
|
|
t.Fatal("expected non-nil health")
|
|
}
|
|
}
|
|
|
|
func TestServicesByCategory(t *testing.T) {
|
|
categories := ServicesByCategory()
|
|
if len(categories) == 0 {
|
|
t.Fatal("expected categories")
|
|
}
|
|
if _, ok := categories["llm"]; !ok {
|
|
t.Fatal("expected 'llm' category")
|
|
}
|
|
if _, ok := categories["code_assist"]; !ok {
|
|
t.Fatal("expected 'code_assist' category")
|
|
}
|
|
}
|
|
|
|
func TestController_EventBounded(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
// Override maxEvents for testing.
|
|
ctrl.mu.Lock()
|
|
ctrl.maxEvents = 10
|
|
ctrl.mu.Unlock()
|
|
|
|
for i := 0; i < 20; i++ {
|
|
ctrl.ProcessNetworkEvent(context.Background(), NetworkEvent{
|
|
User: "user1",
|
|
Destination: "api.openai.com",
|
|
Timestamp: time.Now(),
|
|
})
|
|
}
|
|
|
|
events := ctrl.GetEvents(100)
|
|
if len(events) > 10 {
|
|
t.Fatalf("expected max 10 events, got %d", len(events))
|
|
}
|
|
}
|
|
|
|
// =====================================================
|
|
// Phase 3: Document Review Bridge Tests
|
|
// =====================================================
|
|
|
|
func TestDocBridge_CleanContent(t *testing.T) {
|
|
db := NewDocBridge()
|
|
result := db.ScanDocument("doc-1", "This is clean text without any PII or secrets.", "user1")
|
|
if result.Status != DocReviewClean {
|
|
t.Fatalf("expected clean, got %s", result.Status)
|
|
}
|
|
if result.DataClass != DataPublic {
|
|
t.Fatalf("expected PUBLIC, got %s", result.DataClass)
|
|
}
|
|
if len(result.PIIFound) != 0 {
|
|
t.Fatalf("expected 0 PII, got %d", len(result.PIIFound))
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_DetectEmail(t *testing.T) {
|
|
db := NewDocBridge()
|
|
result := db.ScanDocument("doc-2", "Please contact john.doe@example.com for details.", "user1")
|
|
if len(result.PIIFound) == 0 {
|
|
t.Fatal("expected PII detection for email")
|
|
}
|
|
found := false
|
|
for _, pii := range result.PIIFound {
|
|
if pii.Type == "email" {
|
|
found = true
|
|
if pii.Masked == "" {
|
|
t.Fatal("expected masked email")
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected email PII type")
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_DetectSSN(t *testing.T) {
|
|
db := NewDocBridge()
|
|
result := db.ScanDocument("doc-3", "SSN: 123-45-6789", "user1")
|
|
found := false
|
|
for _, pii := range result.PIIFound {
|
|
if pii.Type == "ssn" {
|
|
found = true
|
|
if pii.Masked != "***-**-****" {
|
|
t.Fatalf("expected masked SSN, got %q", pii.Masked)
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected SSN detection")
|
|
}
|
|
if result.DataClass != DataCritical {
|
|
t.Fatalf("SSN should classify as CRITICAL, got %s", result.DataClass)
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_DetectCreditCard(t *testing.T) {
|
|
db := NewDocBridge()
|
|
result := db.ScanDocument("doc-4", "Card: 4111 1111 1111 1111", "user1")
|
|
found := false
|
|
for _, pii := range result.PIIFound {
|
|
if pii.Type == "credit_card" {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected credit_card detection")
|
|
}
|
|
if result.DataClass != DataCritical {
|
|
t.Fatalf("credit card should classify as CRITICAL, got %s", result.DataClass)
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_DetectAWSKey(t *testing.T) {
|
|
db := NewDocBridge()
|
|
result := db.ScanDocument("doc-5", "AWS key: AKIAIOSFODNN7EXAMPLE", "user1")
|
|
if len(result.SecretsFound) == 0 {
|
|
t.Fatal("expected AWS key detection")
|
|
}
|
|
found := false
|
|
for _, s := range result.SecretsFound {
|
|
if s.Provider == "AWS" {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected AWS provider")
|
|
}
|
|
if result.Status != DocReviewBlocked {
|
|
t.Fatalf("secrets should block, got %s", result.Status)
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_DetectGitHubToken(t *testing.T) {
|
|
db := NewDocBridge()
|
|
result := db.ScanDocument("doc-6", "token: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghij", "user1")
|
|
found := false
|
|
for _, s := range result.SecretsFound {
|
|
if s.Provider == "GitHub" {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected GitHub token detection")
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_RedactContent(t *testing.T) {
|
|
db := NewDocBridge()
|
|
content := "Email: john@example.com, SSN: 123-45-6789, Key: AKIAIOSFODNN7EXAMPLE"
|
|
redacted := db.RedactContent(content)
|
|
|
|
if redacted == content {
|
|
t.Fatal("expected content to be modified")
|
|
}
|
|
// Email should be partially masked.
|
|
if strings.Contains(redacted, "john@example.com") {
|
|
t.Fatal("email should be redacted")
|
|
}
|
|
// SSN should be fully masked.
|
|
if strings.Contains(redacted, "123-45-6789") {
|
|
t.Fatal("SSN should be redacted")
|
|
}
|
|
// AWS key should be replaced.
|
|
if strings.Contains(redacted, "AKIAIOSFODNN7EXAMPLE") {
|
|
t.Fatal("AWS key should be redacted")
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_GetReview(t *testing.T) {
|
|
db := NewDocBridge()
|
|
db.ScanDocument("doc-7", "clean text", "user1")
|
|
|
|
r, ok := db.GetReview("doc-7")
|
|
if !ok || r == nil {
|
|
t.Fatal("review not found")
|
|
}
|
|
if r.DocumentID != "doc-7" {
|
|
t.Fatalf("expected doc-7, got %s", r.DocumentID)
|
|
}
|
|
}
|
|
|
|
func TestDocBridge_Stats(t *testing.T) {
|
|
db := NewDocBridge()
|
|
db.ScanDocument("d1", "clean text", "u1")
|
|
db.ScanDocument("d2", "email: a@b.com", "u1")
|
|
db.ScanDocument("d3", "key: AKIAIOSFODNN7EXAMPLE", "u1")
|
|
|
|
stats := db.Stats()
|
|
if stats["total"] != 3 {
|
|
t.Fatalf("expected 3 total, got %d", stats["total"])
|
|
}
|
|
if stats["clean"] != 1 {
|
|
t.Fatalf("expected 1 clean, got %d", stats["clean"])
|
|
}
|
|
}
|
|
|
|
// =====================================================
|
|
// Phase 3: Approval Engine Tests
|
|
// =====================================================
|
|
|
|
func TestApproval_AutoApprove_Public(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
req := ae.SubmitRequest("user1", "doc-1", DataPublic)
|
|
if req.Status != string(ApprovalAutoApproved) {
|
|
t.Fatalf("expected auto_approved for PUBLIC, got %s", req.Status)
|
|
}
|
|
if req.ApprovedBy != "system" {
|
|
t.Fatalf("expected approved by system, got %s", req.ApprovedBy)
|
|
}
|
|
}
|
|
|
|
func TestApproval_PendingInternal(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
req := ae.SubmitRequest("user1", "doc-2", DataInternal)
|
|
if req.Status != string(ApprovalPending) {
|
|
t.Fatalf("expected pending for INTERNAL, got %s", req.Status)
|
|
}
|
|
if req.ExpiresAt.IsZero() {
|
|
t.Fatal("expected non-zero expiry for INTERNAL")
|
|
}
|
|
}
|
|
|
|
func TestApproval_ApproveFlow(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
req := ae.SubmitRequest("user1", "doc-3", DataConfidential)
|
|
|
|
if err := ae.Approve(req.ID, "manager1"); err != nil {
|
|
t.Fatalf("approve: %v", err)
|
|
}
|
|
|
|
got, ok := ae.GetRequest(req.ID)
|
|
if !ok {
|
|
t.Fatal("request not found after approval")
|
|
}
|
|
if got.Status != string(ApprovalApproved) {
|
|
t.Fatalf("expected approved, got %s", got.Status)
|
|
}
|
|
if got.ApprovedBy != "manager1" {
|
|
t.Fatalf("expected manager1, got %s", got.ApprovedBy)
|
|
}
|
|
}
|
|
|
|
func TestApproval_DenyFlow(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
req := ae.SubmitRequest("user1", "doc-4", DataCritical)
|
|
|
|
if err := ae.Deny(req.ID, "ciso", "data too sensitive"); err != nil {
|
|
t.Fatalf("deny: %v", err)
|
|
}
|
|
|
|
got, _ := ae.GetRequest(req.ID)
|
|
if got.Status != string(ApprovalDenied) {
|
|
t.Fatalf("expected denied, got %s", got.Status)
|
|
}
|
|
if got.Reason != "data too sensitive" {
|
|
t.Fatalf("expected reason, got %q", got.Reason)
|
|
}
|
|
}
|
|
|
|
func TestApproval_DoubleApprove(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
req := ae.SubmitRequest("user1", "doc-5", DataInternal)
|
|
_ = ae.Approve(req.ID, "mgr")
|
|
|
|
err := ae.Approve(req.ID, "mgr2")
|
|
if err == nil {
|
|
t.Fatal("expected error on double approve")
|
|
}
|
|
}
|
|
|
|
func TestApproval_ExpireOverdue(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
req := ae.SubmitRequest("user1", "doc-6", DataInternal)
|
|
|
|
// Manually set ExpiresAt to the past.
|
|
ae.mu.Lock()
|
|
ae.requests[req.ID].ExpiresAt = time.Now().Add(-1 * time.Hour)
|
|
ae.mu.Unlock()
|
|
|
|
expired := ae.ExpireOverdue()
|
|
if expired != 1 {
|
|
t.Fatalf("expected 1 expired, got %d", expired)
|
|
}
|
|
|
|
got, _ := ae.GetRequest(req.ID)
|
|
if got.Status != string(ApprovalExpired) {
|
|
t.Fatalf("expected expired, got %s", got.Status)
|
|
}
|
|
}
|
|
|
|
func TestApproval_Stats(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
ae.SubmitRequest("u1", "d1", DataPublic) // auto
|
|
ae.SubmitRequest("u2", "d2", DataInternal) // pending
|
|
req := ae.SubmitRequest("u3", "d3", DataConfidential) // pending
|
|
_ = ae.Deny(req.ID, "ciso", "no")
|
|
|
|
stats := ae.Stats()
|
|
if stats["auto_approved"] != 1 {
|
|
t.Fatalf("expected 1 auto_approved, got %d", stats["auto_approved"])
|
|
}
|
|
if stats["pending"] != 1 {
|
|
t.Fatalf("expected 1 pending, got %d", stats["pending"])
|
|
}
|
|
if stats["denied"] != 1 {
|
|
t.Fatalf("expected 1 denied, got %d", stats["denied"])
|
|
}
|
|
}
|
|
|
|
func TestApproval_Tiers(t *testing.T) {
|
|
ae := NewApprovalEngine()
|
|
tiers := ae.Tiers()
|
|
if len(tiers) != 4 {
|
|
t.Fatalf("expected 4 tiers, got %d", len(tiers))
|
|
}
|
|
}
|
|
|
|
// =====================================================
|
|
// Phase 3: Vendor Plugin Stubs
|
|
// =====================================================
|
|
|
|
func TestPlugins_RegisterDefault(t *testing.T) {
|
|
reg := NewPluginRegistry()
|
|
RegisterDefaultPlugins(reg)
|
|
|
|
// Provide required config for each vendor stub.
|
|
cfg := &IntegrationConfig{
|
|
Plugins: []PluginConfig{
|
|
{Type: PluginTypeFirewall, Vendor: "checkpoint", Enabled: true, Config: map[string]interface{}{"api_url": "https://cp.local"}},
|
|
{Type: PluginTypeEDR, Vendor: "crowdstrike", Enabled: true, Config: map[string]interface{}{"client_id": "test-id"}},
|
|
{Type: PluginTypeProxy, Vendor: "zscaler", Enabled: true, Config: map[string]interface{}{"cloud_name": "zscaler.net"}},
|
|
},
|
|
}
|
|
_ = reg.LoadPlugins(cfg)
|
|
|
|
if reg.PluginCount() != 3 {
|
|
t.Fatalf("expected 3 plugins, got %d", reg.PluginCount())
|
|
}
|
|
}
|
|
|
|
func TestPlugins_CheckPoint_Vendor(t *testing.T) {
|
|
cp := NewCheckPointEnforcer()
|
|
if cp.Vendor() != "checkpoint" {
|
|
t.Fatalf("expected 'checkpoint', got %s", cp.Vendor())
|
|
}
|
|
}
|
|
|
|
func TestPlugins_CrowdStrike_Vendor(t *testing.T) {
|
|
cs := NewCrowdStrikeController()
|
|
if cs.Vendor() != "crowdstrike" {
|
|
t.Fatalf("expected 'crowdstrike', got %s", cs.Vendor())
|
|
}
|
|
}
|
|
|
|
func TestPlugins_Zscaler_Vendor(t *testing.T) {
|
|
z := NewZscalerGateway()
|
|
if z.Vendor() != "zscaler" {
|
|
t.Fatalf("expected 'zscaler', got %s", z.Vendor())
|
|
}
|
|
}
|
|
|
|
// =====================================================
|
|
// Phase 3: Correlation Rules
|
|
// =====================================================
|
|
|
|
func TestCorrelation_RuleCount(t *testing.T) {
|
|
rules := ShadowAICorrelationRules()
|
|
if len(rules) != 9 {
|
|
t.Fatalf("expected 9 correlation rules, got %d", len(rules))
|
|
}
|
|
}
|
|
|
|
func TestCorrelation_RuleIDs(t *testing.T) {
|
|
rules := ShadowAICorrelationRules()
|
|
ids := make(map[string]bool)
|
|
for _, r := range rules {
|
|
if ids[r.ID] {
|
|
t.Fatalf("duplicate rule ID: %s", r.ID)
|
|
}
|
|
ids[r.ID] = true
|
|
}
|
|
}
|
|
|
|
func TestCorrelation_CriticalRules(t *testing.T) {
|
|
rules := ShadowAICorrelationRules()
|
|
critical := 0
|
|
for _, r := range rules {
|
|
if r.Severity == "CRITICAL" {
|
|
critical++
|
|
}
|
|
}
|
|
if critical < 3 {
|
|
t.Fatalf("expected at least 3 CRITICAL rules, got %d", critical)
|
|
}
|
|
}
|
|
|
|
// =====================================================
|
|
// Phase 3: Controller Integration (DocBridge + Approval)
|
|
// =====================================================
|
|
|
|
func TestController_ReviewDocument_Clean(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
result, approval := ctrl.ReviewDocument("doc-1", "clean content", "user1")
|
|
if result.Status != DocReviewClean {
|
|
t.Fatalf("expected clean, got %s", result.Status)
|
|
}
|
|
if approval == nil {
|
|
t.Fatal("expected auto-approval for clean doc")
|
|
}
|
|
if approval.Status != string(ApprovalAutoApproved) {
|
|
t.Fatalf("expected auto_approved, got %s", approval.Status)
|
|
}
|
|
}
|
|
|
|
func TestController_ReviewDocument_WithPII(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
result, approval := ctrl.ReviewDocument("doc-2", "Contact: alice@corp.com", "user1")
|
|
if result.Status != DocReviewRedacted {
|
|
t.Fatalf("expected redacted, got %s", result.Status)
|
|
}
|
|
if approval == nil {
|
|
t.Fatal("expected approval request for PII")
|
|
}
|
|
if approval.Status != string(ApprovalPending) {
|
|
t.Fatalf("expected pending, got %s", approval.Status)
|
|
}
|
|
}
|
|
|
|
func TestController_ReviewDocument_WithSecrets(t *testing.T) {
|
|
ctrl := NewShadowAIController()
|
|
result, approval := ctrl.ReviewDocument("doc-3", "key: AKIAIOSFODNN7EXAMPLE", "user1")
|
|
if result.Status != DocReviewBlocked {
|
|
t.Fatalf("expected blocked, got %s", result.Status)
|
|
}
|
|
// Should NOT create approval for blocked docs.
|
|
if approval != nil {
|
|
t.Fatal("blocked docs should not create approval")
|
|
}
|
|
}
|