mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-26 21:06:21 +02:00
212 lines
5.6 KiB
Go
212 lines
5.6 KiB
Go
package shadow_ai
|
|
|
|
import (
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
)
|
|
|
|
// PluginFactory builds a fresh, uninitialized plugin instance each time it is
// called. The registry type-asserts the returned value against the capability
// interfaces it knows about (Initializer, NetworkEnforcer, EndpointController,
// WebGateway), so the concrete type is up to the vendor implementation.
type PluginFactory func() any
|
|
|
|
// PluginRegistry manages vendor plugin registration, loading, and lifecycle.
|
|
// Thread-safe via sync.RWMutex.
|
|
type PluginRegistry struct {
|
|
mu sync.RWMutex
|
|
plugins map[string]interface{} // vendor → plugin instance
|
|
factories map[string]PluginFactory // "type_vendor" → factory
|
|
configs map[string]*PluginConfig // vendor → config
|
|
health map[string]*PluginHealth // vendor → health status
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewPluginRegistry creates a new plugin registry.
|
|
func NewPluginRegistry() *PluginRegistry {
|
|
return &PluginRegistry{
|
|
plugins: make(map[string]interface{}),
|
|
factories: make(map[string]PluginFactory),
|
|
configs: make(map[string]*PluginConfig),
|
|
health: make(map[string]*PluginHealth),
|
|
logger: slog.Default().With("component", "shadow-ai-registry"),
|
|
}
|
|
}
|
|
|
|
// RegisterFactory registers a plugin factory for a given type+vendor combination.
|
|
// Example: RegisterFactory("firewall", "checkpoint", func() interface{} { return &CheckPointEnforcer{} })
|
|
func (r *PluginRegistry) RegisterFactory(pluginType PluginType, vendor string, factory PluginFactory) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
key := fmt.Sprintf("%s_%s", pluginType, vendor)
|
|
r.factories[key] = factory
|
|
r.logger.Info("factory registered", "type", pluginType, "vendor", vendor)
|
|
}
|
|
|
|
// LoadPlugins creates and initializes plugins from configuration.
|
|
// Plugins that fail to initialize are logged but do not block other plugins.
|
|
func (r *PluginRegistry) LoadPlugins(config *IntegrationConfig) error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
loaded := 0
|
|
for i := range config.Plugins {
|
|
pluginCfg := &config.Plugins[i]
|
|
if !pluginCfg.Enabled {
|
|
r.logger.Debug("plugin disabled, skipping", "vendor", pluginCfg.Vendor)
|
|
continue
|
|
}
|
|
|
|
key := fmt.Sprintf("%s_%s", pluginCfg.Type, pluginCfg.Vendor)
|
|
factory, exists := r.factories[key]
|
|
if !exists {
|
|
r.logger.Warn("no factory for plugin", "key", key, "vendor", pluginCfg.Vendor)
|
|
continue
|
|
}
|
|
|
|
plugin := factory()
|
|
|
|
// Initialize if plugin supports it.
|
|
if init, ok := plugin.(Initializer); ok {
|
|
if err := init.Initialize(pluginCfg.Config); err != nil {
|
|
r.logger.Error("plugin init failed", "vendor", pluginCfg.Vendor, "error", err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
r.plugins[pluginCfg.Vendor] = plugin
|
|
r.configs[pluginCfg.Vendor] = pluginCfg
|
|
r.health[pluginCfg.Vendor] = &PluginHealth{
|
|
Vendor: pluginCfg.Vendor,
|
|
Type: pluginCfg.Type,
|
|
Status: PluginStatusHealthy,
|
|
}
|
|
loaded++
|
|
r.logger.Info("plugin loaded", "vendor", pluginCfg.Vendor, "type", pluginCfg.Type)
|
|
}
|
|
|
|
r.logger.Info("plugin loading complete", "loaded", loaded, "total", len(config.Plugins))
|
|
return nil
|
|
}
|
|
|
|
// Get returns a plugin by vendor name.
|
|
func (r *PluginRegistry) Get(vendor string) (interface{}, bool) {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
p, ok := r.plugins[vendor]
|
|
return p, ok
|
|
}
|
|
|
|
// GetByType returns all plugins of a given type.
|
|
func (r *PluginRegistry) GetByType(pluginType PluginType) []interface{} {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
var result []interface{}
|
|
for vendor, cfg := range r.configs {
|
|
if cfg.Type == pluginType {
|
|
if plugin, ok := r.plugins[vendor]; ok {
|
|
result = append(result, plugin)
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetNetworkEnforcers returns all loaded NetworkEnforcer plugins.
|
|
func (r *PluginRegistry) GetNetworkEnforcers() []NetworkEnforcer {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
var result []NetworkEnforcer
|
|
for _, plugin := range r.plugins {
|
|
if ne, ok := plugin.(NetworkEnforcer); ok {
|
|
result = append(result, ne)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetEndpointControllers returns all loaded EndpointController plugins.
|
|
func (r *PluginRegistry) GetEndpointControllers() []EndpointController {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
var result []EndpointController
|
|
for _, plugin := range r.plugins {
|
|
if ec, ok := plugin.(EndpointController); ok {
|
|
result = append(result, ec)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetWebGateways returns all loaded WebGateway plugins.
|
|
func (r *PluginRegistry) GetWebGateways() []WebGateway {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
var result []WebGateway
|
|
for _, plugin := range r.plugins {
|
|
if wg, ok := plugin.(WebGateway); ok {
|
|
result = append(result, wg)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// IsHealthy returns true if a plugin is currently healthy.
|
|
func (r *PluginRegistry) IsHealthy(vendor string) bool {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
h, ok := r.health[vendor]
|
|
return ok && h.Status == PluginStatusHealthy
|
|
}
|
|
|
|
// SetHealth updates the health status for a plugin.
|
|
func (r *PluginRegistry) SetHealth(vendor string, health *PluginHealth) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.health[vendor] = health
|
|
}
|
|
|
|
// GetHealth returns the health status snapshot for a plugin.
|
|
func (r *PluginRegistry) GetHealth(vendor string) (*PluginHealth, bool) {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
h, ok := r.health[vendor]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
cp := *h
|
|
return &cp, true
|
|
}
|
|
|
|
// AllHealth returns health snapshots for all plugins.
|
|
func (r *PluginRegistry) AllHealth() []PluginHealth {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
result := make([]PluginHealth, 0, len(r.health))
|
|
for _, h := range r.health {
|
|
result = append(result, *h)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// PluginCount returns the number of loaded plugins.
|
|
func (r *PluginRegistry) PluginCount() int {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return len(r.plugins)
|
|
}
|
|
|
|
// Vendors returns all loaded vendor names.
|
|
func (r *PluginRegistry) Vendors() []string {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
result := make([]string, 0, len(r.plugins))
|
|
for v := range r.plugins {
|
|
result = append(result, v)
|
|
}
|
|
return result
|
|
}
|