initial: Syntrex extraction from sentinel-community (615 files)

This commit is contained in:
DmitrL-dev 2026-03-11 15:12:02 +10:00
commit 2c50c993b1
175 changed files with 32396 additions and 0 deletions

View file

@ -0,0 +1,52 @@
package contextengine
import (
"encoding/json"
"os"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
)
// LoadConfig loads engine configuration from a JSON file.
// If the file does not exist, returns DefaultEngineConfig.
// If the file exists but is invalid, returns an error.
func LoadConfig(path string) (ctxdomain.EngineConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return ctxdomain.DefaultEngineConfig(), nil
}
return ctxdomain.EngineConfig{}, err
}
var cfg ctxdomain.EngineConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return ctxdomain.EngineConfig{}, err
}
// Build skip set from deserialized SkipTools slice.
cfg.BuildSkipSet()
// If skip_tools was omitted in JSON, use defaults.
if cfg.SkipTools == nil {
cfg.SkipTools = ctxdomain.DefaultSkipTools()
cfg.BuildSkipSet()
}
if err := cfg.Validate(); err != nil {
return ctxdomain.EngineConfig{}, err
}
return cfg, nil
}
// SaveDefaultConfig writes the default configuration to a JSON file.
// Useful for bootstrapping .rlm/context.json.
func SaveDefaultConfig(path string) error {
cfg := ctxdomain.DefaultEngineConfig()
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}

View file

@ -0,0 +1,110 @@
package contextengine
import (
"os"
"path/filepath"
"testing"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoadConfig_FileNotExists(t *testing.T) {
cfg, err := LoadConfig("/nonexistent/path/context.json")
require.NoError(t, err)
assert.Equal(t, ctxdomain.DefaultTokenBudget, cfg.TokenBudget)
assert.True(t, cfg.Enabled)
assert.NotEmpty(t, cfg.SkipTools)
}
func TestLoadConfig_ValidFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
content := `{
"token_budget": 500,
"max_facts": 15,
"recency_weight": 0.3,
"frequency_weight": 0.2,
"level_weight": 0.25,
"keyword_weight": 0.25,
"decay_half_life_hours": 48,
"enabled": true,
"skip_tools": ["health", "version"]
}`
err := os.WriteFile(path, []byte(content), 0o644)
require.NoError(t, err)
cfg, err := LoadConfig(path)
require.NoError(t, err)
assert.Equal(t, 500, cfg.TokenBudget)
assert.Equal(t, 15, cfg.MaxFacts)
assert.Equal(t, 0.3, cfg.RecencyWeight)
assert.Equal(t, 48.0, cfg.DecayHalfLifeHours)
assert.True(t, cfg.Enabled)
assert.Len(t, cfg.SkipTools, 2)
assert.True(t, cfg.ShouldSkip("health"))
assert.True(t, cfg.ShouldSkip("version"))
assert.False(t, cfg.ShouldSkip("search_facts"))
}
func TestLoadConfig_InvalidJSON(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
err := os.WriteFile(path, []byte("{invalid json"), 0o644)
require.NoError(t, err)
_, err = LoadConfig(path)
assert.Error(t, err)
}
func TestLoadConfig_InvalidConfig(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
content := `{"token_budget": 0, "max_facts": 5, "recency_weight": 0.5, "frequency_weight": 0.5, "level_weight": 0, "keyword_weight": 0, "decay_half_life_hours": 24, "enabled": true}`
err := os.WriteFile(path, []byte(content), 0o644)
require.NoError(t, err)
_, err = LoadConfig(path)
assert.Error(t, err, "should fail validation: token_budget=0")
}
func TestLoadConfig_OmittedSkipTools_UsesDefaults(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
content := `{
"token_budget": 300,
"max_facts": 10,
"recency_weight": 0.25,
"frequency_weight": 0.15,
"level_weight": 0.30,
"keyword_weight": 0.30,
"decay_half_life_hours": 72,
"enabled": true
}`
err := os.WriteFile(path, []byte(content), 0o644)
require.NoError(t, err)
cfg, err := LoadConfig(path)
require.NoError(t, err)
assert.NotEmpty(t, cfg.SkipTools, "omitted skip_tools should use defaults")
assert.True(t, cfg.ShouldSkip("search_facts"))
}
func TestSaveDefaultConfig(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
err := SaveDefaultConfig(path)
require.NoError(t, err)
// Verify we can load what we saved
cfg, err := LoadConfig(path)
require.NoError(t, err)
assert.Equal(t, ctxdomain.DefaultTokenBudget, cfg.TokenBudget)
assert.True(t, cfg.Enabled)
}

View file

@ -0,0 +1,203 @@
// Package contextengine implements the Proactive Context Engine.
// It automatically injects relevant memory facts into every MCP tool response
// via ToolHandlerMiddleware, so the LLM always has context without asking.
package contextengine
import (
"context"
"log"
"strings"
"sync"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
)
// InteractionLogger records tool calls for crash-safe memory.
// Implementations must be safe for concurrent use.
type InteractionLogger interface {
Record(ctx context.Context, toolName string, args map[string]interface{}) error
}
// Engine is the Proactive Context Engine. It scores facts by relevance,
// selects the top ones within a token budget, and injects them into
// every tool response as a [MEMORY CONTEXT] block.
type Engine struct {
config ctxdomain.EngineConfig
scorer *ctxdomain.RelevanceScorer
provider ctxdomain.FactProvider
logger InteractionLogger // optional, nil = no logging
mu sync.RWMutex
accessCounts map[string]int // in-memory access counters per fact ID
}
// New creates a new Proactive Context Engine.
func New(cfg ctxdomain.EngineConfig, provider ctxdomain.FactProvider) *Engine {
return &Engine{
config: cfg,
scorer: ctxdomain.NewRelevanceScorer(cfg),
provider: provider,
accessCounts: make(map[string]int),
}
}
// SetInteractionLogger attaches an optional interaction logger for crash-safe
// tool call recording. If set, every tool call passing through the middleware
// will be recorded fire-and-forget (errors logged, never propagated).
func (e *Engine) SetInteractionLogger(l InteractionLogger) {
e.logger = l
}
// IsEnabled returns whether the engine is active.
func (e *Engine) IsEnabled() bool {
return e.config.Enabled
}
// BuildContext scores and selects relevant facts for the given tool call,
// returning a ContextFrame ready for formatting and injection.
func (e *Engine) BuildContext(toolName string, args map[string]interface{}) *ctxdomain.ContextFrame {
frame := ctxdomain.NewContextFrame(toolName, e.config.TokenBudget)
if !e.config.Enabled {
return frame
}
// Extract keywords from all string arguments
keywords := e.extractKeywordsFromArgs(args)
// Get candidate facts from provider
facts, err := e.provider.GetRelevantFacts(args)
if err != nil || len(facts) == 0 {
return frame
}
// Get current access counts snapshot
e.mu.RLock()
countsCopy := make(map[string]int, len(e.accessCounts))
for k, v := range e.accessCounts {
countsCopy[k] = v
}
e.mu.RUnlock()
// Score and rank facts
ranked := e.scorer.RankFacts(facts, keywords, countsCopy)
// Fill frame within token budget and max facts
added := 0
for _, sf := range ranked {
if added >= e.config.MaxFacts {
break
}
if frame.AddFact(sf) {
added++
// Record access for reinforcement
e.recordAccessInternal(sf.Fact.ID)
e.provider.RecordAccess(sf.Fact.ID)
}
}
return frame
}
// GetAccessCount returns the internal access count for a fact.
func (e *Engine) GetAccessCount(factID string) int {
e.mu.RLock()
defer e.mu.RUnlock()
return e.accessCounts[factID]
}
// recordAccessInternal increments the in-memory access counter.
func (e *Engine) recordAccessInternal(factID string) {
e.mu.Lock()
e.accessCounts[factID]++
e.mu.Unlock()
}
// Middleware returns a ToolHandlerMiddleware that wraps every tool handler
// to inject relevant memory context into the response and optionally
// record tool calls to the interaction log for crash-safe memory.
func (e *Engine) Middleware() server.ToolHandlerMiddleware {
return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Fire-and-forget: record this tool call in the interaction log.
// This runs BEFORE the handler so the record is persisted even if
// the process is killed mid-handler.
if e.logger != nil {
if logErr := e.logger.Record(ctx, req.Params.Name, req.GetArguments()); logErr != nil {
log.Printf("contextengine: interaction log error: %v", logErr)
}
}
// Call the original handler
result, err := next(ctx, req)
if err != nil {
return result, err
}
// Don't inject on nil result, error results, or empty content
if result == nil || result.IsError || len(result.Content) == 0 {
return result, nil
}
// Don't inject if engine is disabled
if !e.IsEnabled() {
return result, nil
}
// Don't inject for tools in the skip list
if e.config.ShouldSkip(req.Params.Name) {
return result, nil
}
// Build context frame
frame := e.BuildContext(req.Params.Name, req.GetArguments())
contextText := frame.Format()
if contextText == "" {
return result, nil
}
// Append context to the last text content block
e.appendContextToResult(result, contextText)
return result, nil
}
}
}
// appendContextToResult appends the context text to the last TextContent in the result.
func (e *Engine) appendContextToResult(result *mcp.CallToolResult, contextText string) {
for i := len(result.Content) - 1; i >= 0; i-- {
if tc, ok := result.Content[i].(mcp.TextContent); ok {
tc.Text += contextText
result.Content[i] = tc
return
}
}
// No text content found — add a new one
result.Content = append(result.Content, mcp.TextContent{
Type: "text",
Text: contextText,
})
}
// extractKeywordsFromArgs extracts keywords from all string values in the arguments map.
func (e *Engine) extractKeywordsFromArgs(args map[string]interface{}) []string {
if len(args) == 0 {
return nil
}
var allText strings.Builder
for _, v := range args {
switch val := v.(type) {
case string:
allText.WriteString(val)
allText.WriteString(" ")
}
}
return ctxdomain.ExtractKeywords(allText.String())
}

View file

@ -0,0 +1,708 @@
package contextengine
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/sentinel-community/gomcp/internal/domain/memory"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Mock FactProvider ---
type mockProvider struct {
mu sync.Mutex
facts []*memory.Fact
l0 []*memory.Fact
// tracks RecordAccess calls
accessed map[string]int
}
func newMockProvider(facts ...*memory.Fact) *mockProvider {
l0 := make([]*memory.Fact, 0)
for _, f := range facts {
if f.Level == memory.LevelProject {
l0 = append(l0, f)
}
}
return &mockProvider{
facts: facts,
l0: l0,
accessed: make(map[string]int),
}
}
func (m *mockProvider) GetRelevantFacts(_ map[string]interface{}) ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.facts, nil
}
func (m *mockProvider) GetL0Facts() ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.l0, nil
}
func (m *mockProvider) RecordAccess(factID string) {
m.mu.Lock()
defer m.mu.Unlock()
m.accessed[factID]++
}
// --- Engine tests ---
func TestNewEngine(t *testing.T) {
cfg := ctxdomain.DefaultEngineConfig()
provider := newMockProvider()
engine := New(cfg, provider)
require.NotNil(t, engine)
assert.True(t, engine.IsEnabled())
}
func TestNewEngine_Disabled(t *testing.T) {
cfg := ctxdomain.DefaultEngineConfig()
cfg.Enabled = false
engine := New(cfg, newMockProvider())
assert.False(t, engine.IsEnabled())
}
func TestEngine_BuildContext_NoFacts(t *testing.T) {
cfg := ctxdomain.DefaultEngineConfig()
provider := newMockProvider()
engine := New(cfg, provider)
frame := engine.BuildContext("test_tool", map[string]interface{}{
"content": "hello world",
})
assert.NotNil(t, frame)
assert.Empty(t, frame.Facts)
assert.Equal(t, "", frame.Format())
}
func TestEngine_BuildContext_WithFacts(t *testing.T) {
fact1 := memory.NewFact("Architecture uses clean layers", memory.LevelProject, "arch", "")
fact2 := memory.NewFact("TDD is mandatory for all code", memory.LevelProject, "process", "")
fact3 := memory.NewFact("Random snippet from old session", memory.LevelSnippet, "misc", "")
fact3.CreatedAt = time.Now().Add(-90 * 24 * time.Hour) // very old
provider := newMockProvider(fact1, fact2, fact3)
cfg := ctxdomain.DefaultEngineConfig()
cfg.TokenBudget = 500
engine := New(cfg, provider)
frame := engine.BuildContext("add_fact", map[string]interface{}{
"content": "architecture decision",
})
require.NotNil(t, frame)
assert.NotEmpty(t, frame.Facts)
// L0 facts should be included and ranked higher
assert.Equal(t, "add_fact", frame.ToolName)
formatted := frame.Format()
assert.Contains(t, formatted, "[MEMORY CONTEXT]")
assert.Contains(t, formatted, "[/MEMORY CONTEXT]")
}
func TestEngine_BuildContext_RespectsTokenBudget(t *testing.T) {
// Create many facts that exceed token budget
facts := make([]*memory.Fact, 50)
for i := 0; i < 50; i++ {
facts[i] = memory.NewFact(
fmt.Sprintf("Fact number %d with enough content to consume tokens in the budget allocation system", i),
memory.LevelProject, "arch", "",
)
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
cfg.TokenBudget = 100 // tight budget
cfg.MaxFacts = 50
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"query": "test"})
assert.LessOrEqual(t, frame.TokensUsed, cfg.TokenBudget)
}
func TestEngine_BuildContext_RespectsMaxFacts(t *testing.T) {
facts := make([]*memory.Fact, 20)
for i := 0; i < 20; i++ {
facts[i] = memory.NewFact(fmt.Sprintf("Fact %d", i), memory.LevelProject, "arch", "")
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
cfg.MaxFacts = 5
cfg.TokenBudget = 10000 // large budget so max_facts is the limiter
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"query": "fact"})
assert.LessOrEqual(t, len(frame.Facts), cfg.MaxFacts)
}
func TestEngine_BuildContext_DisabledReturnsEmpty(t *testing.T) {
fact := memory.NewFact("test", memory.LevelProject, "arch", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
cfg.Enabled = false
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"content": "test"})
assert.Empty(t, frame.Facts)
assert.Equal(t, "", frame.Format())
}
func TestEngine_RecordsAccess(t *testing.T) {
fact1 := memory.NewFact("Architecture pattern", memory.LevelProject, "arch", "")
provider := newMockProvider(fact1)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"content": "architecture"})
require.NotEmpty(t, frame.Facts)
// Check that RecordAccess was called on the provider
provider.mu.Lock()
count := provider.accessed[fact1.ID]
provider.mu.Unlock()
assert.Greater(t, count, 0, "RecordAccess should be called for injected facts")
}
func TestEngine_AccessCountTracking(t *testing.T) {
fact := memory.NewFact("Architecture decision", memory.LevelProject, "arch", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
// Build context 3 times
for i := 0; i < 3; i++ {
engine.BuildContext("test", map[string]interface{}{"content": "architecture"})
}
// Internal access count should be tracked
count := engine.GetAccessCount(fact.ID)
assert.Equal(t, 3, count)
}
func TestEngine_AccessCountInfluencesRanking(t *testing.T) {
// Two similar facts but one has been accessed more
fact1 := memory.NewFact("Architecture pattern A", memory.LevelDomain, "arch", "")
fact2 := memory.NewFact("Architecture pattern B", memory.LevelDomain, "arch", "")
provider := newMockProvider(fact1, fact2)
cfg := ctxdomain.DefaultEngineConfig()
cfg.FrequencyWeight = 0.9 // heavily weight frequency
cfg.KeywordWeight = 0.01
cfg.RecencyWeight = 0.01
cfg.LevelWeight = 0.01
engine := New(cfg, provider)
// Simulate fact1 being accessed many times
for i := 0; i < 20; i++ {
engine.recordAccessInternal(fact1.ID)
}
frame := engine.BuildContext("test", map[string]interface{}{"content": "architecture pattern"})
require.GreaterOrEqual(t, len(frame.Facts), 2)
// fact1 should rank higher due to frequency
assert.Equal(t, fact1.ID, frame.Facts[0].Fact.ID)
}
// --- Middleware tests ---
func TestMiddleware_InjectsContext(t *testing.T) {
fact := memory.NewFact("Always remember: TDD first", memory.LevelProject, "process", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
// Create a simple handler
handler := func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Original result"},
},
}, nil
}
// Wrap with middleware
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "test_tool"
req.Params.Arguments = map[string]interface{}{
"content": "TDD process",
}
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, result.Content, 1)
text := result.Content[0].(mcp.TextContent).Text
assert.Contains(t, text, "Original result")
assert.Contains(t, text, "[MEMORY CONTEXT]")
assert.Contains(t, text, "TDD first")
}
func TestMiddleware_DisabledPassesThrough(t *testing.T) {
fact := memory.NewFact("should not appear", memory.LevelProject, "test", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
cfg.Enabled = false
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Original only"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.Equal(t, "Original only", text)
assert.NotContains(t, text, "[MEMORY CONTEXT]")
}
func TestMiddleware_HandlerErrorPassedThrough(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("handler error")
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
assert.Error(t, err)
assert.Nil(t, result)
}
func TestMiddleware_ErrorResult_NoInjection(t *testing.T) {
fact := memory.NewFact("should not appear on errors", memory.LevelProject, "test", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Error: something failed"},
},
IsError: true,
}, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.NotContains(t, text, "[MEMORY CONTEXT]", "should not inject context on error results")
}
func TestMiddleware_EmptyContentSlice(t *testing.T) {
provider := newMockProvider(memory.NewFact("test", memory.LevelProject, "a", ""))
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{},
}, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
// Should handle empty content gracefully
assert.NotNil(t, result)
}
func TestMiddleware_NilResult(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return nil, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
assert.Nil(t, result)
}
func TestMiddleware_SkipListTools(t *testing.T) {
fact := memory.NewFact("Should not appear for skipped tools", memory.LevelProject, "test", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Facts result"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
// Tools in default skip list should NOT get context injected
skipTools := []string{"search_facts", "get_fact", "get_l0_facts", "health", "version", "dashboard"}
for _, tool := range skipTools {
t.Run(tool, func(t *testing.T) {
req := mcp.CallToolRequest{}
req.Params.Name = tool
req.Params.Arguments = map[string]interface{}{"query": "test"}
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.NotContains(t, text, "[MEMORY CONTEXT]",
"tool %s is in skip list, should not get context injected", tool)
})
}
}
func TestMiddleware_NonSkipToolGetsContext(t *testing.T) {
fact := memory.NewFact("Important architecture fact", memory.LevelProject, "arch", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Tool result"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_causal_node"
req.Params.Arguments = map[string]interface{}{"content": "architecture decision"}
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.Contains(t, text, "[MEMORY CONTEXT]")
}
// --- Concurrency test ---
func TestEngine_ConcurrentAccess(t *testing.T) {
facts := make([]*memory.Fact, 10)
for i := 0; i < 10; i++ {
facts[i] = memory.NewFact(fmt.Sprintf("Concurrent fact %d", i), memory.LevelProject, "arch", "")
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
engine.BuildContext("tool", map[string]interface{}{
"content": fmt.Sprintf("query %d", n),
})
}(i)
}
wg.Wait()
// Just verify no panics or races (run with -race)
for _, f := range facts {
count := engine.GetAccessCount(f.ID)
assert.GreaterOrEqual(t, count, 0)
}
}
// --- Benchmark ---
func BenchmarkEngine_BuildContext(b *testing.B) {
facts := make([]*memory.Fact, 100)
for i := 0; i < 100; i++ {
facts[i] = memory.NewFact(
"Architecture uses clean layers with dependency injection for modularity",
memory.HierLevel(i%4), "arch", "core",
)
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
args := map[string]interface{}{
"content": "architecture clean layers dependency",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
engine.BuildContext("test_tool", args)
}
}
func BenchmarkMiddleware(b *testing.B) {
facts := make([]*memory.Fact, 50)
for i := 0; i < 50; i++ {
facts[i] = memory.NewFact("test fact content", memory.LevelProject, "arch", "")
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "result"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{"content": "test"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = wrapped(context.Background(), req)
}
}
// --- Mock InteractionLogger ---
type mockInteractionLogger struct {
mu sync.Mutex
entries []logEntry
failErr error // if set, Record returns this error
}
type logEntry struct {
toolName string
args map[string]interface{}
}
func (m *mockInteractionLogger) Record(_ context.Context, toolName string, args map[string]interface{}) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.failErr != nil {
return m.failErr
}
m.entries = append(m.entries, logEntry{toolName: toolName, args: args})
return nil
}
func (m *mockInteractionLogger) count() int {
m.mu.Lock()
defer m.mu.Unlock()
return len(m.entries)
}
func (m *mockInteractionLogger) lastToolName() string {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.entries) == 0 {
return ""
}
return m.entries[len(m.entries)-1].toolName
}
// --- Interaction Logger Tests ---
func TestMiddleware_InteractionLogger_RecordsToolCalls(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_fact"
req.Params.Arguments = map[string]interface{}{"content": "test fact"}
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, 1, logger.count())
assert.Equal(t, "add_fact", logger.lastToolName())
}
func TestMiddleware_InteractionLogger_RecordsSkippedTools(t *testing.T) {
// Even skip-list tools should be recorded in the interaction log
// (skip-list only controls context injection, not logging)
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
// "health" is in the skip list
req := mcp.CallToolRequest{}
req.Params.Name = "health"
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, 1, logger.count(), "skip-list tools should still be logged")
assert.Equal(t, "health", logger.lastToolName())
}
func TestMiddleware_InteractionLogger_ErrorDoesNotBreakHandler(t *testing.T) {
// Logger errors must be swallowed — never break the tool call
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{failErr: fmt.Errorf("disk full")}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "handler succeeded"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_fact"
result, err := wrapped(context.Background(), req)
require.NoError(t, err, "logger error must not propagate")
require.NotNil(t, result)
text := result.Content[0].(mcp.TextContent).Text
assert.Contains(t, text, "handler succeeded")
}
func TestMiddleware_NoLogger_StillWorks(t *testing.T) {
// Without a logger set, middleware should work normally
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
// engine.logger is nil — no SetInteractionLogger call
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "no logger ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_fact"
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, result)
}
func TestMiddleware_InteractionLogger_MultipleToolCalls(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
toolNames := []string{"add_fact", "search_facts", "health", "add_causal_node", "version"}
for _, name := range toolNames {
req := mcp.CallToolRequest{}
req.Params.Name = name
_, _ = wrapped(context.Background(), req)
}
assert.Equal(t, 5, logger.count(), "all 5 tool calls should be logged")
}
func TestMiddleware_InteractionLogger_ConcurrentCalls(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
req := mcp.CallToolRequest{}
req.Params.Name = fmt.Sprintf("tool_%d", n)
_, _ = wrapped(context.Background(), req)
}(i)
}
wg.Wait()
assert.Equal(t, 20, logger.count(), "all 20 concurrent calls should be logged")
}

View file

@ -0,0 +1,245 @@
// Package contextengine — processor.go
// Processes unprocessed interaction log entries into session summary facts.
// This closes the memory loop: tool calls → interaction log → summary facts → boot instructions.
package contextengine
import (
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
)
// InteractionProcessor processes unprocessed interaction log entries
// and creates session summary facts from them.
type InteractionProcessor struct {
repo *sqlite.InteractionLogRepo
factStore memory.FactStore
}
// NewInteractionProcessor creates a new processor.
func NewInteractionProcessor(repo *sqlite.InteractionLogRepo, store memory.FactStore) *InteractionProcessor {
return &InteractionProcessor{repo: repo, factStore: store}
}
// ProcessStartup processes unprocessed entries from a previous (possibly crashed) session.
// It creates an L1 "session summary" fact and marks all entries as processed.
// Returns the summary text (empty if nothing to process).
func (p *InteractionProcessor) ProcessStartup(ctx context.Context) (string, error) {
entries, err := p.repo.GetUnprocessed(ctx)
if err != nil {
return "", fmt.Errorf("get unprocessed: %w", err)
}
if len(entries) == 0 {
return "", nil
}
summary := buildSessionSummary(entries, "previous session (recovered)")
if summary == "" {
return "", nil
}
// Save as L1 fact (domain-level, not project-level)
fact := memory.NewFact(summary, memory.LevelDomain, "session-history", "interaction-processor")
fact.Source = "auto:interaction-processor"
if err := p.factStore.Add(ctx, fact); err != nil {
return "", fmt.Errorf("save session summary fact: %w", err)
}
// Mark all as processed
ids := make([]int64, len(entries))
for i, e := range entries {
ids[i] = e.ID
}
if err := p.repo.MarkProcessed(ctx, ids); err != nil {
return "", fmt.Errorf("mark processed: %w", err)
}
return summary, nil
}
// ProcessShutdown processes entries from the current session at graceful shutdown.
// Similar to ProcessStartup but labels differently.
func (p *InteractionProcessor) ProcessShutdown(ctx context.Context) (string, error) {
entries, err := p.repo.GetUnprocessed(ctx)
if err != nil {
return "", fmt.Errorf("get unprocessed: %w", err)
}
if len(entries) == 0 {
return "", nil
}
summary := buildSessionSummary(entries, "session ending "+time.Now().Format("2006-01-02 15:04"))
if summary == "" {
return "", nil
}
fact := memory.NewFact(summary, memory.LevelDomain, "session-history", "interaction-processor")
fact.Source = "auto:session-shutdown"
if err := p.factStore.Add(ctx, fact); err != nil {
return "", fmt.Errorf("save session summary fact: %w", err)
}
ids := make([]int64, len(entries))
for i, e := range entries {
ids[i] = e.ID
}
if err := p.repo.MarkProcessed(ctx, ids); err != nil {
return "", fmt.Errorf("mark processed: %w", err)
}
return summary, nil
}
// buildSessionSummary creates a compact text summary from interaction log entries.
func buildSessionSummary(entries []sqlite.InteractionEntry, label string) string {
if len(entries) == 0 {
return ""
}
// Count tool calls
toolCounts := make(map[string]int)
for _, e := range entries {
toolCounts[e.ToolName]++
}
// Sort by count descending
type toolStat struct {
name string
count int
}
stats := make([]toolStat, 0, len(toolCounts))
for name, count := range toolCounts {
stats = append(stats, toolStat{name, count})
}
sort.Slice(stats, func(i, j int) bool { return stats[i].count > stats[j].count })
// Extract topics from args (unique string values)
topics := extractTopicsFromEntries(entries)
// Time range
var earliest, latest time.Time
for _, e := range entries {
if earliest.IsZero() || e.Timestamp.Before(earliest) {
earliest = e.Timestamp
}
if latest.IsZero() || e.Timestamp.After(latest) {
latest = e.Timestamp
}
}
var b strings.Builder
b.WriteString(fmt.Sprintf("Session summary (%s): %d tool calls", label, len(entries)))
if !earliest.IsZero() {
duration := latest.Sub(earliest)
if duration > 0 {
b.WriteString(fmt.Sprintf(" over %s", formatDuration(duration)))
}
}
b.WriteString(". ")
// Top tools used
b.WriteString("Tools used: ")
for i, ts := range stats {
if i >= 8 {
b.WriteString(fmt.Sprintf(" +%d more", len(stats)-8))
break
}
if i > 0 {
b.WriteString(", ")
}
b.WriteString(fmt.Sprintf("%s(%d)", ts.name, ts.count))
}
b.WriteString(". ")
// Topics
if len(topics) > 0 {
b.WriteString("Topics: ")
limit := 10
if len(topics) < limit {
limit = len(topics)
}
b.WriteString(strings.Join(topics[:limit], ", "))
if len(topics) > limit {
b.WriteString(fmt.Sprintf(" +%d more", len(topics)-limit))
}
b.WriteString(".")
}
return b.String()
}
// extractTopicsFromEntries pulls unique meaningful strings from tool arguments.
func extractTopicsFromEntries(entries []sqlite.InteractionEntry) []string {
seen := make(map[string]bool)
var topics []string
for _, e := range entries {
if e.ArgsJSON == "" {
continue
}
// Simple extraction: find quoted strings in JSON args
// ArgsJSON looks like {"query":"architecture","content":"some fact"}
parts := strings.Split(e.ArgsJSON, "\"")
for i := 3; i < len(parts); i += 4 {
// Values are at odd positions after the key
val := parts[i]
if len(val) < 3 || len(val) > 100 {
continue
}
// Skip common non-topic values
lower := strings.ToLower(val)
if lower == "true" || lower == "false" || lower == "null" || lower == "" {
continue
}
if !seen[lower] {
seen[lower] = true
topics = append(topics, val)
}
}
}
return topics
}
// formatDuration formats a duration into a human-readable string.
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%ds", int(d.Seconds()))
}
if d < time.Hour {
return fmt.Sprintf("%dm", int(d.Minutes()))
}
return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60)
}
// GetLastSessionSummary searches the fact store for the most recent session summary.
func GetLastSessionSummary(ctx context.Context, store memory.FactStore) string {
facts, err := store.Search(ctx, "Session summary", 5)
if err != nil || len(facts) == 0 {
return ""
}
// Find the most recent one from session-history domain
var best *memory.Fact
for _, f := range facts {
if f.Domain != "session-history" {
continue
}
if f.IsStale || f.IsArchived {
continue
}
if best == nil || f.CreatedAt.After(best.CreatedAt) {
best = f
}
}
if best == nil {
return ""
}
return best.Content
}

View file

@ -0,0 +1,251 @@
package contextengine
import (
"context"
"testing"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- mock FactStore for processor tests ---
type procMockFactStore struct {
facts []*memory.Fact
}
func (m *procMockFactStore) Add(_ context.Context, f *memory.Fact) error {
m.facts = append(m.facts, f)
return nil
}
func (m *procMockFactStore) Get(_ context.Context, id string) (*memory.Fact, error) {
for _, f := range m.facts {
if f.ID == id {
return f, nil
}
}
return nil, nil
}
func (m *procMockFactStore) Update(_ context.Context, _ *memory.Fact) error { return nil }
func (m *procMockFactStore) Delete(_ context.Context, _ string) error { return nil }
func (m *procMockFactStore) ListByDomain(_ context.Context, _ string, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) ListByLevel(_ context.Context, _ memory.HierLevel) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) ListDomains(_ context.Context) ([]string, error) { return nil, nil }
func (m *procMockFactStore) GetStale(_ context.Context, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) Search(_ context.Context, query string, limit int) ([]*memory.Fact, error) {
var results []*memory.Fact
for _, f := range m.facts {
if len(results) >= limit {
break
}
if contains(f.Content, query) {
results = append(results, f)
}
}
return results, nil
}
func (m *procMockFactStore) GetExpired(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
func (m *procMockFactStore) RefreshTTL(_ context.Context, _ string) error { return nil }
func (m *procMockFactStore) TouchFact(_ context.Context, _ string) error { return nil }
func (m *procMockFactStore) GetColdFacts(_ context.Context, _ int) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) CompressFacts(_ context.Context, _ []string, _ string) (string, error) {
return "", nil
}
func (m *procMockFactStore) Stats(_ context.Context) (*memory.FactStoreStats, error) {
return &memory.FactStoreStats{}, nil
}
func (m *procMockFactStore) ListGenes(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
func contains(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(sub) == 0 ||
(len(s) > 0 && len(sub) > 0 && containsStr(s, sub)))
}
func containsStr(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
// --- Tests ---
func TestInteractionProcessor_ProcessStartup_NoEntries(t *testing.T) {
db, err := sqlite.OpenMemory()
require.NoError(t, err)
defer db.Close()
repo, err := sqlite.NewInteractionLogRepo(db)
require.NoError(t, err)
store := &procMockFactStore{}
proc := NewInteractionProcessor(repo, store)
summary, err := proc.ProcessStartup(context.Background())
require.NoError(t, err)
assert.Empty(t, summary)
assert.Empty(t, store.facts, "no facts should be created")
}
func TestInteractionProcessor_ProcessStartup_CreatesFactAndMarksProcessed(t *testing.T) {
db, err := sqlite.OpenMemory()
require.NoError(t, err)
defer db.Close()
repo, err := sqlite.NewInteractionLogRepo(db)
require.NoError(t, err)
ctx := context.Background()
// Insert some tool calls
require.NoError(t, repo.Record(ctx, "add_fact", map[string]interface{}{"content": "test fact about architecture"}))
require.NoError(t, repo.Record(ctx, "search_facts", map[string]interface{}{"query": "security"}))
require.NoError(t, repo.Record(ctx, "health", nil))
require.NoError(t, repo.Record(ctx, "add_fact", map[string]interface{}{"content": "another fact"}))
require.NoError(t, repo.Record(ctx, "dashboard", nil))
store := &procMockFactStore{}
proc := NewInteractionProcessor(repo, store)
summary, err := proc.ProcessStartup(ctx)
require.NoError(t, err)
assert.NotEmpty(t, summary)
assert.Contains(t, summary, "Session summary")
assert.Contains(t, summary, "5 tool calls")
assert.Contains(t, summary, "add_fact(2)")
// Fact should be saved
require.Len(t, store.facts, 1)
assert.Equal(t, memory.LevelDomain, store.facts[0].Level)
assert.Equal(t, "session-history", store.facts[0].Domain)
assert.Equal(t, "auto:interaction-processor", store.facts[0].Source)
// All entries should be marked processed
_, unprocessed, err := repo.Count(ctx)
require.NoError(t, err)
assert.Equal(t, 0, unprocessed)
}
func TestInteractionProcessor_ProcessShutdown(t *testing.T) {
db, err := sqlite.OpenMemory()
require.NoError(t, err)
defer db.Close()
repo, err := sqlite.NewInteractionLogRepo(db)
require.NoError(t, err)
ctx := context.Background()
require.NoError(t, repo.Record(ctx, "version", nil))
require.NoError(t, repo.Record(ctx, "search_facts", map[string]interface{}{"query": "gomcp"}))
store := &procMockFactStore{}
proc := NewInteractionProcessor(repo, store)
summary, err := proc.ProcessShutdown(ctx)
require.NoError(t, err)
assert.Contains(t, summary, "session ending")
assert.Contains(t, summary, "2 tool calls")
require.Len(t, store.facts, 1)
assert.Equal(t, "auto:session-shutdown", store.facts[0].Source)
}
func TestBuildSessionSummary_ToolCounts(t *testing.T) {
now := time.Now()
entries := []sqlite.InteractionEntry{
{ID: 1, ToolName: "add_fact", Timestamp: now},
{ID: 2, ToolName: "add_fact", Timestamp: now},
{ID: 3, ToolName: "add_fact", Timestamp: now},
{ID: 4, ToolName: "search_facts", Timestamp: now},
{ID: 5, ToolName: "health", Timestamp: now},
}
summary := buildSessionSummary(entries, "test")
assert.Contains(t, summary, "5 tool calls")
assert.Contains(t, summary, "add_fact(3)")
assert.Contains(t, summary, "search_facts(1)")
assert.Contains(t, summary, "health(1)")
}
func TestBuildSessionSummary_Duration(t *testing.T) {
now := time.Now()
entries := []sqlite.InteractionEntry{
{ID: 1, ToolName: "a", Timestamp: now.Add(-30 * time.Minute)},
{ID: 2, ToolName: "b", Timestamp: now},
}
summary := buildSessionSummary(entries, "test")
assert.Contains(t, summary, "30m")
}
func TestBuildSessionSummary_Empty(t *testing.T) {
summary := buildSessionSummary(nil, "test")
assert.Empty(t, summary)
}
func TestBuildSessionSummary_Topics(t *testing.T) {
now := time.Now()
entries := []sqlite.InteractionEntry{
{ID: 1, ToolName: "search_facts", ArgsJSON: `{"query":"architecture"}`, Timestamp: now},
{ID: 2, ToolName: "add_fact", ArgsJSON: `{"content":"security review"}`, Timestamp: now},
}
summary := buildSessionSummary(entries, "test")
assert.Contains(t, summary, "Topics:")
assert.Contains(t, summary, "architecture")
}
func TestGetLastSessionSummary_Found(t *testing.T) {
store := &procMockFactStore{}
f := memory.NewFact("Session summary (test): 5 tool calls", memory.LevelDomain, "session-history", "")
f.Source = "auto:session-shutdown"
store.facts = append(store.facts, f)
result := GetLastSessionSummary(context.Background(), store)
assert.Contains(t, result, "Session summary")
}
func TestGetLastSessionSummary_NotFound(t *testing.T) {
store := &procMockFactStore{}
result := GetLastSessionSummary(context.Background(), store)
assert.Empty(t, result)
}
func TestGetLastSessionSummary_SkipsStale(t *testing.T) {
store := &procMockFactStore{}
f := memory.NewFact("Session summary (test): old", memory.LevelDomain, "session-history", "")
f.IsStale = true
store.facts = append(store.facts, f)
result := GetLastSessionSummary(context.Background(), store)
assert.Empty(t, result)
}
func TestFormatDuration(t *testing.T) {
assert.Equal(t, "30s", formatDuration(30*time.Second))
assert.Equal(t, "5m", formatDuration(5*time.Minute))
assert.Equal(t, "2h15m", formatDuration(2*time.Hour+15*time.Minute))
}
func TestExtractTopicsFromEntries(t *testing.T) {
entries := []sqlite.InteractionEntry{
{ArgsJSON: `{"query":"architecture"}`},
{ArgsJSON: `{"content":"security review"}`},
{ArgsJSON: ``}, // empty
}
topics := extractTopicsFromEntries(entries)
assert.NotEmpty(t, topics)
}

View file

@ -0,0 +1,117 @@
package contextengine
import (
"context"
"sync"
"github.com/sentinel-community/gomcp/internal/domain/memory"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
)
// StoreFactProvider adapts FactStore + HotCache to the FactProvider interface,
// bridging infrastructure storage with the context engine domain.
type StoreFactProvider struct {
store memory.FactStore
cache memory.HotCache
mu sync.Mutex
accessCounts map[string]int
}
// NewStoreFactProvider creates a FactProvider backed by FactStore and optional HotCache.
func NewStoreFactProvider(store memory.FactStore, cache memory.HotCache) *StoreFactProvider {
return &StoreFactProvider{
store: store,
cache: cache,
accessCounts: make(map[string]int),
}
}
// Verify interface compliance at compile time.
var _ ctxdomain.FactProvider = (*StoreFactProvider)(nil)
// GetRelevantFacts returns candidate facts for context injection.
// Uses keyword search from tool arguments + L0 facts as candidates.
func (p *StoreFactProvider) GetRelevantFacts(args map[string]interface{}) ([]*memory.Fact, error) {
ctx := context.Background()
// Always include L0 facts
l0Facts, err := p.GetL0Facts()
if err != nil {
return nil, err
}
// Extract query text from arguments for search
query := extractQueryFromArgs(args)
if query == "" {
return l0Facts, nil
}
// Search for additional relevant facts
searchResults, err := p.store.Search(ctx, query, 30)
if err != nil {
// Degrade gracefully — just return L0 facts
return l0Facts, nil
}
// Merge L0 + search results, deduplicating by ID
seen := make(map[string]bool, len(l0Facts))
merged := make([]*memory.Fact, 0, len(l0Facts)+len(searchResults))
for _, f := range l0Facts {
seen[f.ID] = true
merged = append(merged, f)
}
for _, f := range searchResults {
if !seen[f.ID] {
seen[f.ID] = true
merged = append(merged, f)
}
}
return merged, nil
}
// GetL0Facts returns all L0 (project-level) facts.
// Uses HotCache if available, falls back to store.
func (p *StoreFactProvider) GetL0Facts() ([]*memory.Fact, error) {
ctx := context.Background()
if p.cache != nil {
facts, err := p.cache.GetL0Facts(ctx)
if err == nil && len(facts) > 0 {
return facts, nil
}
}
return p.store.ListByLevel(ctx, memory.LevelProject)
}
// RecordAccess increments the access counter for a fact.
func (p *StoreFactProvider) RecordAccess(factID string) {
p.mu.Lock()
p.accessCounts[factID]++
p.mu.Unlock()
}
// extractQueryFromArgs builds a search query string from argument values.
func extractQueryFromArgs(args map[string]interface{}) string {
var parts []string
for _, v := range args {
if s, ok := v.(string); ok && s != "" {
parts = append(parts, s)
}
}
if len(parts) == 0 {
return ""
}
result := ""
for i, p := range parts {
if i > 0 {
result += " "
}
result += p
}
return result
}

View file

@ -0,0 +1,278 @@
package contextengine
import (
"context"
"fmt"
"sync"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Mock FactStore for provider tests ---
type mockFactStore struct {
mu sync.Mutex
facts map[string]*memory.Fact
searchFacts []*memory.Fact
searchErr error
levelFacts map[memory.HierLevel][]*memory.Fact
}
func newMockFactStore() *mockFactStore {
return &mockFactStore{
facts: make(map[string]*memory.Fact),
levelFacts: make(map[memory.HierLevel][]*memory.Fact),
}
}
func (m *mockFactStore) Add(_ context.Context, fact *memory.Fact) error {
m.mu.Lock()
defer m.mu.Unlock()
m.facts[fact.ID] = fact
m.levelFacts[fact.Level] = append(m.levelFacts[fact.Level], fact)
return nil
}
func (m *mockFactStore) Get(_ context.Context, id string) (*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
f, ok := m.facts[id]
if !ok {
return nil, fmt.Errorf("not found")
}
return f, nil
}
func (m *mockFactStore) Update(_ context.Context, fact *memory.Fact) error {
m.mu.Lock()
defer m.mu.Unlock()
m.facts[fact.ID] = fact
return nil
}
func (m *mockFactStore) Delete(_ context.Context, id string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.facts, id)
return nil
}
func (m *mockFactStore) ListByDomain(_ context.Context, _ string, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) ListByLevel(_ context.Context, level memory.HierLevel) ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.levelFacts[level], nil
}
func (m *mockFactStore) ListDomains(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *mockFactStore) GetStale(_ context.Context, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) Search(_ context.Context, _ string, _ int) ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.searchFacts, m.searchErr
}
func (m *mockFactStore) GetExpired(_ context.Context) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) RefreshTTL(_ context.Context, _ string) error {
return nil
}
func (m *mockFactStore) TouchFact(_ context.Context, _ string) error { return nil }
func (m *mockFactStore) GetColdFacts(_ context.Context, _ int) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) CompressFacts(_ context.Context, _ []string, _ string) (string, error) {
return "", nil
}
func (m *mockFactStore) Stats(_ context.Context) (*memory.FactStoreStats, error) {
return nil, nil
}
func (m *mockFactStore) ListGenes(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
// --- Mock HotCache ---
type mockHotCache struct {
l0Facts []*memory.Fact
l0Err error
}
func (m *mockHotCache) GetL0Facts(_ context.Context) ([]*memory.Fact, error) {
return m.l0Facts, m.l0Err
}
func (m *mockHotCache) InvalidateFact(_ context.Context, _ string) error { return nil }
func (m *mockHotCache) WarmUp(_ context.Context, _ []*memory.Fact) error { return nil }
func (m *mockHotCache) Close() error { return nil }
// --- StoreFactProvider tests ---
func TestNewStoreFactProvider(t *testing.T) {
store := newMockFactStore()
provider := NewStoreFactProvider(store, nil)
require.NotNil(t, provider)
}
func TestStoreFactProvider_GetL0Facts_FromStore(t *testing.T) {
store := newMockFactStore()
f1 := memory.NewFact("L0 fact A", memory.LevelProject, "arch", "")
f2 := memory.NewFact("L0 fact B", memory.LevelProject, "process", "")
_ = store.Add(context.Background(), f1)
_ = store.Add(context.Background(), f2)
provider := NewStoreFactProvider(store, nil) // no cache
facts, err := provider.GetL0Facts()
require.NoError(t, err)
assert.Len(t, facts, 2)
}
func TestStoreFactProvider_GetL0Facts_FromCache(t *testing.T) {
store := newMockFactStore()
cacheFact := memory.NewFact("Cached L0", memory.LevelProject, "arch", "")
cache := &mockHotCache{l0Facts: []*memory.Fact{cacheFact}}
provider := NewStoreFactProvider(store, cache)
facts, err := provider.GetL0Facts()
require.NoError(t, err)
assert.Len(t, facts, 1)
assert.Equal(t, "Cached L0", facts[0].Content)
}
func TestStoreFactProvider_GetL0Facts_CacheFallbackToStore(t *testing.T) {
store := newMockFactStore()
storeFact := memory.NewFact("Store L0", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), storeFact)
cache := &mockHotCache{l0Facts: nil, l0Err: fmt.Errorf("cache miss")}
provider := NewStoreFactProvider(store, cache)
facts, err := provider.GetL0Facts()
require.NoError(t, err)
assert.Len(t, facts, 1)
assert.Equal(t, "Store L0", facts[0].Content)
}
func TestStoreFactProvider_GetRelevantFacts_NoQuery(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 always included", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), l0)
provider := NewStoreFactProvider(store, nil)
// No string args → no query → only L0
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"level": 0,
})
require.NoError(t, err)
assert.Len(t, facts, 1)
}
func TestStoreFactProvider_GetRelevantFacts_WithSearch(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 architecture", memory.LevelProject, "arch", "")
searchResult := memory.NewFact("Found by search", memory.LevelDomain, "auth", "")
_ = store.Add(context.Background(), l0)
store.searchFacts = []*memory.Fact{searchResult}
provider := NewStoreFactProvider(store, nil)
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"content": "authentication module",
})
require.NoError(t, err)
assert.Len(t, facts, 2) // L0 + search result
}
func TestStoreFactProvider_GetRelevantFacts_Deduplication(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 architecture", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), l0)
// Search returns the same fact that's also L0
store.searchFacts = []*memory.Fact{l0}
provider := NewStoreFactProvider(store, nil)
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"content": "architecture",
})
require.NoError(t, err)
assert.Len(t, facts, 1, "duplicate should be removed")
}
func TestStoreFactProvider_GetRelevantFacts_SearchError_GracefulDegradation(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 fact", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), l0)
store.searchErr = fmt.Errorf("search broken")
provider := NewStoreFactProvider(store, nil)
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"content": "test query",
})
require.NoError(t, err, "should degrade gracefully, not error")
assert.Len(t, facts, 1, "should still return L0 facts")
}
func TestStoreFactProvider_RecordAccess(t *testing.T) {
store := newMockFactStore()
provider := NewStoreFactProvider(store, nil)
provider.RecordAccess("fact-1")
provider.RecordAccess("fact-1")
provider.RecordAccess("fact-2")
provider.mu.Lock()
assert.Equal(t, 2, provider.accessCounts["fact-1"])
assert.Equal(t, 1, provider.accessCounts["fact-2"])
provider.mu.Unlock()
}
func TestExtractQueryFromArgs(t *testing.T) {
tests := []struct {
name string
args map[string]interface{}
want string
}{
{"nil", nil, ""},
{"empty", map[string]interface{}{}, ""},
{"no strings", map[string]interface{}{"level": 0, "flag": true}, ""},
{"single string", map[string]interface{}{"content": "hello"}, "hello"},
{"empty string", map[string]interface{}{"content": ""}, ""},
{"mixed", map[string]interface{}{"content": "hello", "level": 0, "domain": "arch"}, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractQueryFromArgs(tt.args)
if tt.name == "mixed" {
// Map iteration order is non-deterministic, just check non-empty
assert.NotEmpty(t, got)
} else {
if tt.want == "" {
assert.Empty(t, got)
} else {
assert.Contains(t, got, tt.want)
}
}
})
}
}

View file

@ -0,0 +1,94 @@
// Package lifecycle manages graceful shutdown with auto-save of session state,
// cache flush, and database closure.
package lifecycle
import (
"context"
"io"
"log"
"sync"
"time"
)
// ShutdownFunc is a function called during graceful shutdown.
// Name is used for logging. The function receives a context with a deadline.
type ShutdownFunc struct {
Name string
Fn func(ctx context.Context) error
}
// Manager orchestrates graceful shutdown of all resources.
type Manager struct {
mu sync.Mutex
hooks []ShutdownFunc
timeout time.Duration
done bool
}
// NewManager creates a new lifecycle Manager.
// Timeout is the maximum time allowed for all shutdown hooks to complete.
func NewManager(timeout time.Duration) *Manager {
if timeout <= 0 {
timeout = 10 * time.Second
}
return &Manager{
timeout: timeout,
}
}
// OnShutdown registers a shutdown hook. Hooks are called in LIFO order
// (last registered = first called), matching defer semantics.
func (m *Manager) OnShutdown(name string, fn func(ctx context.Context) error) {
m.mu.Lock()
defer m.mu.Unlock()
m.hooks = append(m.hooks, ShutdownFunc{Name: name, Fn: fn})
}
// OnClose registers an io.Closer as a shutdown hook.
func (m *Manager) OnClose(name string, c io.Closer) {
m.OnShutdown(name, func(_ context.Context) error {
return c.Close()
})
}
// Shutdown executes all registered hooks in reverse order (LIFO).
// It logs each step and any errors. Returns the first error encountered.
func (m *Manager) Shutdown() error {
m.mu.Lock()
if m.done {
m.mu.Unlock()
return nil
}
m.done = true
hooks := make([]ShutdownFunc, len(m.hooks))
copy(hooks, m.hooks)
m.mu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
defer cancel()
log.Printf("Graceful shutdown started (%d hooks, timeout %s)", len(hooks), m.timeout)
var firstErr error
// Execute in reverse order (LIFO).
for i := len(hooks) - 1; i >= 0; i-- {
h := hooks[i]
log.Printf(" shutdown: %s", h.Name)
if err := h.Fn(ctx); err != nil {
log.Printf(" shutdown %s: ERROR: %v", h.Name, err)
if firstErr == nil {
firstErr = err
}
}
}
log.Printf("Graceful shutdown complete")
return firstErr
}
// Done returns true if Shutdown has already been called.
func (m *Manager) Done() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.done
}

View file

@ -0,0 +1,125 @@
package lifecycle
import (
"context"
"errors"
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewManager_Defaults(t *testing.T) {
m := NewManager(0)
require.NotNil(t, m)
assert.Equal(t, 10*time.Second, m.timeout)
assert.False(t, m.Done())
}
func TestNewManager_CustomTimeout(t *testing.T) {
m := NewManager(5 * time.Second)
assert.Equal(t, 5*time.Second, m.timeout)
}
func TestManager_Shutdown_LIFO(t *testing.T) {
m := NewManager(5 * time.Second)
order := []string{}
m.OnShutdown("first", func(_ context.Context) error {
order = append(order, "first")
return nil
})
m.OnShutdown("second", func(_ context.Context) error {
order = append(order, "second")
return nil
})
m.OnShutdown("third", func(_ context.Context) error {
order = append(order, "third")
return nil
})
err := m.Shutdown()
require.NoError(t, err)
assert.Equal(t, []string{"third", "second", "first"}, order)
assert.True(t, m.Done())
}
func TestManager_Shutdown_Idempotent(t *testing.T) {
m := NewManager(5 * time.Second)
count := 0
m.OnShutdown("counter", func(_ context.Context) error {
count++
return nil
})
_ = m.Shutdown()
_ = m.Shutdown()
_ = m.Shutdown()
assert.Equal(t, 1, count)
}
func TestManager_Shutdown_ReturnsFirstError(t *testing.T) {
m := NewManager(5 * time.Second)
errFirst := errors.New("first error")
errSecond := errors.New("second error")
m.OnShutdown("ok", func(_ context.Context) error { return nil })
m.OnShutdown("fail1", func(_ context.Context) error { return errFirst })
m.OnShutdown("fail2", func(_ context.Context) error { return errSecond })
// LIFO: fail2 runs first, then fail1, then ok.
err := m.Shutdown()
assert.Equal(t, errSecond, err)
}
func TestManager_Shutdown_ContinuesOnError(t *testing.T) {
m := NewManager(5 * time.Second)
reached := false
m.OnShutdown("will-run", func(_ context.Context) error {
reached = true
return nil
})
m.OnShutdown("will-fail", func(_ context.Context) error {
return errors.New("fail")
})
_ = m.Shutdown()
assert.True(t, reached, "hook after error should still run")
}
type mockCloser struct {
closed bool
}
func (m *mockCloser) Close() error {
m.closed = true
return nil
}
func TestManager_OnClose(t *testing.T) {
m := NewManager(5 * time.Second)
mc := &mockCloser{}
m.OnClose("mock-closer", mc)
_ = m.Shutdown()
assert.True(t, mc.closed)
}
func TestManager_OnClose_Interface(t *testing.T) {
m := NewManager(5 * time.Second)
// Verify OnClose accepts io.Closer interface.
var c io.Closer = &mockCloser{}
m.OnClose("io-closer", c)
err := m.Shutdown()
require.NoError(t, err)
}
func TestManager_EmptyShutdown(t *testing.T) {
m := NewManager(5 * time.Second)
err := m.Shutdown()
require.NoError(t, err)
assert.True(t, m.Done())
}

View file

@ -0,0 +1,75 @@
package lifecycle
import (
"crypto/rand"
"fmt"
"log"
"os"
)
// ShredDatabase irreversibly destroys a database file by overwriting
// its header with random bytes, making it unreadable without backup.
//
// For SQLite: overwrites first 100 bytes (header with magic bytes "SQLite format 3\000").
// For BoltDB: overwrites first 4096 bytes (two 4KB meta pages).
//
// WARNING: This operation is IRREVERSIBLE. Data is only recoverable from peer backup.
func ShredDatabase(dbPath string, headerSize int) error {
f, err := os.OpenFile(dbPath, os.O_WRONLY, 0)
if err != nil {
return fmt.Errorf("shred: open %s: %w", dbPath, err)
}
defer f.Close()
// Overwrite header with random bytes.
noise := make([]byte, headerSize)
if _, err := rand.Read(noise); err != nil {
return fmt.Errorf("shred: random: %w", err)
}
if _, err := f.WriteAt(noise, 0); err != nil {
return fmt.Errorf("shred: write %s: %w", dbPath, err)
}
// Force flush to disk.
if err := f.Sync(); err != nil {
return fmt.Errorf("shred: sync %s: %w", dbPath, err)
}
log.Printf("SHRED: %s header (%d bytes) destroyed", dbPath, headerSize)
return nil
}
// ShredSQLite shreds a SQLite database (100-byte header).
func ShredSQLite(dbPath string) error {
return ShredDatabase(dbPath, 100)
}
// ShredBoltDB shreds a BoltDB database (4096-byte meta pages).
func ShredBoltDB(dbPath string) error {
return ShredDatabase(dbPath, 4096)
}
// ShredAll shreds all known database files in the .rlm directory.
func ShredAll(rlmDir string) []error {
var errs []error
sqlitePath := rlmDir + "/memory/memory_bridge_v2.db"
if _, err := os.Stat(sqlitePath); err == nil {
if err := ShredSQLite(sqlitePath); err != nil {
errs = append(errs, err)
}
}
boltPath := rlmDir + "/cache.db"
if _, err := os.Stat(boltPath); err == nil {
if err := ShredBoltDB(boltPath); err != nil {
errs = append(errs, err)
}
}
if len(errs) == 0 {
log.Printf("SHRED: All databases destroyed in %s", rlmDir)
}
return errs
}

View file

@ -0,0 +1,75 @@
package lifecycle
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShredSQLite(t *testing.T) {
dir := t.TempDir()
dbPath := filepath.Join(dir, "test.db")
// Create fake SQLite file with magic header.
header := []byte("SQLite format 3\x00")
data := make([]byte, 4096)
copy(data, header)
require.NoError(t, os.WriteFile(dbPath, data, 0644))
// Verify magic exists.
content, _ := os.ReadFile(dbPath)
assert.Equal(t, "SQLite format 3", string(content[:15]))
// Shred.
err := ShredSQLite(dbPath)
assert.NoError(t, err)
// Verify magic is destroyed.
content, _ = os.ReadFile(dbPath)
assert.NotEqual(t, "SQLite format 3", string(content[:15]),
"SQLite header should be shredded")
assert.Len(t, content, 4096, "file size should not change")
}
func TestShredBoltDB(t *testing.T) {
dir := t.TempDir()
dbPath := filepath.Join(dir, "cache.db")
// Create fake BoltDB file.
data := make([]byte, 8192) // 2 pages
copy(data, []byte("BOLT\x00\x00"))
require.NoError(t, os.WriteFile(dbPath, data, 0644))
err := ShredBoltDB(dbPath)
assert.NoError(t, err)
content, _ := os.ReadFile(dbPath)
assert.NotEqual(t, "BOLT", string(content[:4]),
"BoltDB header should be shredded")
}
func TestShredAll(t *testing.T) {
dir := t.TempDir()
// Create directory structure.
memDir := filepath.Join(dir, "memory")
os.MkdirAll(memDir, 0755)
// Create fake databases.
os.WriteFile(filepath.Join(memDir, "memory_bridge_v2.db"),
make([]byte, 4096), 0644)
os.WriteFile(filepath.Join(dir, "cache.db"),
make([]byte, 8192), 0644)
errs := ShredAll(dir)
assert.Empty(t, errs, "should shred without errors")
}
func TestShred_NonexistentFile(t *testing.T) {
err := ShredSQLite("/nonexistent/path/db.sqlite")
assert.Error(t, err, "should error on nonexistent file")
}

View file

@ -0,0 +1,70 @@
package orchestrator
import (
"encoding/json"
"fmt"
"os"
"time"
)
// JSONConfig is the v3.4 file-based configuration for the Orchestrator.
// Loaded from .rlm/config.json. Overrides compiled defaults.
type JSONConfig struct {
HeartbeatIntervalSec int `json:"heartbeat_interval_sec,omitempty"`
JitterPercent int `json:"jitter_percent,omitempty"`
EntropyThreshold float64 `json:"entropy_threshold,omitempty"`
MaxSyncBatchSize int `json:"max_sync_batch_size,omitempty"`
SynapseIntervalMult int `json:"synapse_interval_multiplier,omitempty"` // default: 12
}
// LoadConfigFromFile reads .rlm/config.json and returns a Config.
// Missing or invalid file → returns defaults silently.
func LoadConfigFromFile(path string) Config {
cfg := Config{
HeartbeatInterval: 5 * time.Minute,
JitterPercent: 30,
EntropyThreshold: 0.8,
MaxSyncBatchSize: 100,
}
data, err := os.ReadFile(path)
if err != nil {
return cfg // File not found → use defaults.
}
var jcfg JSONConfig
if err := json.Unmarshal(data, &jcfg); err != nil {
return cfg // Invalid JSON → use defaults.
}
if jcfg.HeartbeatIntervalSec > 0 {
cfg.HeartbeatInterval = time.Duration(jcfg.HeartbeatIntervalSec) * time.Second
}
if jcfg.JitterPercent > 0 && jcfg.JitterPercent <= 100 {
cfg.JitterPercent = jcfg.JitterPercent
}
if jcfg.EntropyThreshold > 0 && jcfg.EntropyThreshold <= 1.0 {
cfg.EntropyThreshold = jcfg.EntropyThreshold
}
if jcfg.MaxSyncBatchSize > 0 {
cfg.MaxSyncBatchSize = jcfg.MaxSyncBatchSize
}
return cfg
}
// WriteDefaultConfig writes a default config.json to the given path.
func WriteDefaultConfig(path string) error {
jcfg := JSONConfig{
HeartbeatIntervalSec: 300,
JitterPercent: 30,
EntropyThreshold: 0.8,
MaxSyncBatchSize: 100,
SynapseIntervalMult: 12,
}
data, err := json.MarshalIndent(jcfg, "", " ")
if err != nil {
return fmt.Errorf("marshal config: %w", err)
}
return os.WriteFile(path, data, 0o644)
}

View file

@ -0,0 +1,72 @@
package orchestrator
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoadConfigFromFile_Defaults(t *testing.T) {
// Non-existent file → defaults.
cfg := LoadConfigFromFile("/nonexistent/config.json")
assert.Equal(t, 5*time.Minute, cfg.HeartbeatInterval)
assert.Equal(t, 30, cfg.JitterPercent)
assert.InDelta(t, 0.8, cfg.EntropyThreshold, 0.001)
assert.Equal(t, 100, cfg.MaxSyncBatchSize)
}
func TestLoadConfigFromFile_Custom(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.json")
err := os.WriteFile(path, []byte(`{
"heartbeat_interval_sec": 60,
"jitter_percent": 10,
"entropy_threshold": 0.5,
"max_sync_batch_size": 50
}`), 0o644)
require.NoError(t, err)
cfg := LoadConfigFromFile(path)
assert.Equal(t, 60*time.Second, cfg.HeartbeatInterval)
assert.Equal(t, 10, cfg.JitterPercent)
assert.InDelta(t, 0.5, cfg.EntropyThreshold, 0.001)
assert.Equal(t, 50, cfg.MaxSyncBatchSize)
}
func TestLoadConfigFromFile_InvalidJSON(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad.json")
os.WriteFile(path, []byte(`{invalid}`), 0o644)
cfg := LoadConfigFromFile(path)
// Should return defaults on invalid JSON.
assert.Equal(t, 5*time.Minute, cfg.HeartbeatInterval)
}
func TestWriteDefaultConfig(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.json")
err := WriteDefaultConfig(path)
require.NoError(t, err)
cfg := LoadConfigFromFile(path)
assert.Equal(t, 5*time.Minute, cfg.HeartbeatInterval)
assert.Equal(t, 30, cfg.JitterPercent)
}
func TestLoadConfigFromFile_PartialOverride(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "partial.json")
os.WriteFile(path, []byte(`{"heartbeat_interval_sec": 120}`), 0o644)
cfg := LoadConfigFromFile(path)
assert.Equal(t, 120*time.Second, cfg.HeartbeatInterval)
// Other fields should be defaults.
assert.Equal(t, 30, cfg.JitterPercent)
assert.InDelta(t, 0.8, cfg.EntropyThreshold, 0.001)
}

View file

@ -0,0 +1,806 @@
// Package orchestrator implements the DIP Heartbeat Orchestrator.
//
// The orchestrator runs a background loop with 4 modules:
// 1. Auto-Discovery — monitors configured peer endpoints for new Merkle-compatible nodes
// 2. Sync Manager — auto-syncs L0-L1 facts between trusted peers on changes
// 3. Stability Watchdog — monitors entropy and triggers apoptosis recovery
// 4. Jittered Heartbeat — randomizes intervals to avoid detection patterns
//
// The orchestrator works with domain-level components directly (not through MCP tools).
// It is started as a goroutine from main.go and runs until context cancellation.
package orchestrator
import (
"context"
"fmt"
"log"
"math/rand"
"sync"
"time"
"github.com/sentinel-community/gomcp/internal/domain/alert"
"github.com/sentinel-community/gomcp/internal/domain/entropy"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/domain/peer"
"github.com/sentinel-community/gomcp/internal/domain/synapse"
)
// Config holds orchestrator configuration.
type Config struct {
// HeartbeatInterval is the base interval between heartbeat cycles.
HeartbeatInterval time.Duration `json:"heartbeat_interval"`
// JitterPercent is the percentage of HeartbeatInterval to add/subtract randomly.
// e.g., 30 means ±30% jitter around the base interval.
JitterPercent int `json:"jitter_percent"`
// EntropyThreshold triggers apoptosis recovery when exceeded (0.0-1.0).
EntropyThreshold float64 `json:"entropy_threshold"`
// KnownPeers are pre-configured peer genome hashes for auto-discovery.
// Format: "node_name:genome_hash"
KnownPeers []string `json:"known_peers"`
// SyncOnChange triggers sync when new local facts are detected.
SyncOnChange bool `json:"sync_on_change"`
// MaxSyncBatchSize limits facts per sync payload.
MaxSyncBatchSize int `json:"max_sync_batch_size"`
}
// DefaultConfig returns sensible defaults.
func DefaultConfig() Config {
return Config{
HeartbeatInterval: 5 * time.Minute,
JitterPercent: 30,
EntropyThreshold: 0.95,
SyncOnChange: true,
MaxSyncBatchSize: 100,
}
}
// HeartbeatResult records what happened in one heartbeat cycle.
type HeartbeatResult struct {
Cycle int `json:"cycle"`
StartedAt time.Time `json:"started_at"`
Duration time.Duration `json:"duration"`
PeersDiscovered int `json:"peers_discovered"`
FactsSynced int `json:"facts_synced"`
EntropyLevel float64 `json:"entropy_level"`
ApoptosisTriggered bool `json:"apoptosis_triggered"`
GenomeIntact bool `json:"genome_intact"`
GenesHealed int `json:"genes_healed"`
FactsExpired int `json:"facts_expired"`
FactsArchived int `json:"facts_archived"`
SynapsesCreated int `json:"synapses_created"` // v3.4: Module 9
NextInterval time.Duration `json:"next_interval"`
Errors []string `json:"errors,omitempty"`
}
// Orchestrator runs the DIP heartbeat pipeline.
type Orchestrator struct {
mu sync.RWMutex
config Config
peerReg *peer.Registry
store memory.FactStore
synapseStore synapse.SynapseStore // v3.4: Module 9
alertBus *alert.Bus
running bool
cycle int
history []HeartbeatResult
lastSync time.Time
lastFactCount int
}
// New creates a new orchestrator.
func New(cfg Config, peerReg *peer.Registry, store memory.FactStore) *Orchestrator {
if cfg.HeartbeatInterval <= 0 {
cfg.HeartbeatInterval = 5 * time.Minute
}
if cfg.JitterPercent <= 0 || cfg.JitterPercent > 100 {
cfg.JitterPercent = 30
}
if cfg.EntropyThreshold <= 0 {
cfg.EntropyThreshold = 0.8
}
if cfg.MaxSyncBatchSize <= 0 {
cfg.MaxSyncBatchSize = 100
}
return &Orchestrator{
config: cfg,
peerReg: peerReg,
store: store,
history: make([]HeartbeatResult, 0, 64),
}
}
// NewWithAlerts creates an orchestrator with an alert bus for DIP-Watcher.
func NewWithAlerts(cfg Config, peerReg *peer.Registry, store memory.FactStore, bus *alert.Bus) *Orchestrator {
o := New(cfg, peerReg, store)
o.alertBus = bus
return o
}
// OrchestratorStatus is the v3.4 observability snapshot.
type OrchestratorStatus struct {
Running bool `json:"running"`
Cycle int `json:"cycle"`
Config Config `json:"config"`
LastResult *HeartbeatResult `json:"last_result,omitempty"`
HistorySize int `json:"history_size"`
HasSynapseStore bool `json:"has_synapse_store"`
}
// Status returns current orchestrator state (v3.4: observability).
func (o *Orchestrator) Status() OrchestratorStatus {
o.mu.RLock()
defer o.mu.RUnlock()
status := OrchestratorStatus{
Running: o.running,
Cycle: o.cycle,
Config: o.config,
HistorySize: len(o.history),
HasSynapseStore: o.synapseStore != nil,
}
if len(o.history) > 0 {
last := o.history[len(o.history)-1]
status.LastResult = &last
}
return status
}
// AlertBus returns the alert bus (may be nil).
func (o *Orchestrator) AlertBus() *alert.Bus {
return o.alertBus
}
// Start begins the heartbeat loop. Blocks until context is cancelled.
func (o *Orchestrator) Start(ctx context.Context) {
o.mu.Lock()
o.running = true
o.mu.Unlock()
defer func() {
o.mu.Lock()
o.running = false
o.mu.Unlock()
}()
log.Printf("orchestrator: started (interval=%s, jitter=±%d%%, entropy_threshold=%.2f)",
o.config.HeartbeatInterval, o.config.JitterPercent, o.config.EntropyThreshold)
for {
result := o.heartbeat(ctx)
o.mu.Lock()
o.history = append(o.history, result)
// Keep last 64 results.
if len(o.history) > 64 {
o.history = o.history[len(o.history)-64:]
}
o.mu.Unlock()
if result.ApoptosisTriggered {
log.Printf("orchestrator: apoptosis triggered at cycle %d, entropy=%.4f",
result.Cycle, result.EntropyLevel)
}
// Jittered sleep.
select {
case <-ctx.Done():
log.Printf("orchestrator: stopped after %d cycles", o.cycle)
return
case <-time.After(result.NextInterval):
}
}
}
// heartbeat executes one cycle of the pipeline.
func (o *Orchestrator) heartbeat(ctx context.Context) HeartbeatResult {
o.mu.Lock()
o.cycle++
cycle := o.cycle
o.mu.Unlock()
start := time.Now()
result := HeartbeatResult{
Cycle: cycle,
StartedAt: start,
}
// --- Module 1: Auto-Discovery ---
discovered := o.autoDiscover(ctx)
result.PeersDiscovered = discovered
// --- Module 2: Stability Watchdog (genome + entropy check) ---
genomeOK, entropyLevel := o.stabilityCheck(ctx, &result)
result.GenomeIntact = genomeOK
result.EntropyLevel = entropyLevel
// --- Module 3: Sync Manager ---
if genomeOK && !result.ApoptosisTriggered {
synced := o.syncManager(ctx, &result)
result.FactsSynced = synced
}
// --- Module 4: Self-Healing (auto-restore missing genes) ---
healed := o.selfHeal(ctx, &result)
result.GenesHealed = healed
// --- Module 5: Memory Hygiene (expire stale, archive old) ---
expired, archived := o.memoryHygiene(ctx, &result)
result.FactsExpired = expired
result.FactsArchived = archived
// --- Module 6: State Persistence (auto-snapshot) ---
o.statePersistence(ctx, &result)
// --- Module 7: Jittered interval ---
result.NextInterval = o.jitteredInterval()
result.Duration = time.Since(start)
// --- Module 8: DIP-Watcher (proactive alert generation) ---
o.dipWatcher(&result)
// --- Module 9: Synapse Scanner (v3.4) ---
if o.synapseStore != nil && cycle%12 == 0 {
created := o.synapseScanner(ctx, &result)
result.SynapsesCreated = created
}
log.Printf("orchestrator: cycle=%d peers=%d synced=%d healed=%d expired=%d archived=%d synapses=%d entropy=%.4f genome=%v next=%s",
cycle, discovered, result.FactsSynced, healed, expired, archived, result.SynapsesCreated, entropyLevel, genomeOK, result.NextInterval)
return result
}
// dipWatcher is Module 8: proactive monitoring that generates alerts
// based on heartbeat metrics. Feeds the TUI alert panel.
func (o *Orchestrator) dipWatcher(result *HeartbeatResult) {
if o.alertBus == nil {
return
}
cycle := result.Cycle
// --- Entropy monitoring ---
if result.EntropyLevel > 0.9 {
o.alertBus.Emit(alert.New(alert.SourceEntropy, alert.SeverityCritical,
fmt.Sprintf("CRITICAL entropy: %.4f (threshold: 0.90)", result.EntropyLevel), cycle).
WithValue(result.EntropyLevel))
} else if result.EntropyLevel > 0.7 {
o.alertBus.Emit(alert.New(alert.SourceEntropy, alert.SeverityWarning,
fmt.Sprintf("Elevated entropy: %.4f", result.EntropyLevel), cycle).
WithValue(result.EntropyLevel))
}
// --- Genome integrity ---
if !result.GenomeIntact {
o.alertBus.Emit(alert.New(alert.SourceGenome, alert.SeverityCritical,
"Genome integrity FAILED — Merkle root mismatch", cycle))
}
if result.ApoptosisTriggered {
o.alertBus.Emit(alert.New(alert.SourceSystem, alert.SeverityCritical,
"APOPTOSIS triggered — emergency genome preservation", cycle))
}
// --- Self-healing events ---
if result.GenesHealed > 0 {
o.alertBus.Emit(alert.New(alert.SourceGenome, alert.SeverityWarning,
fmt.Sprintf("Self-healed %d missing genes", result.GenesHealed), cycle))
}
// --- Memory hygiene ---
if result.FactsExpired > 5 {
o.alertBus.Emit(alert.New(alert.SourceMemory, alert.SeverityWarning,
fmt.Sprintf("Memory cleanup: %d expired, %d archived",
result.FactsExpired, result.FactsArchived), cycle))
}
// --- Heartbeat health ---
if result.Duration > 2*o.config.HeartbeatInterval {
o.alertBus.Emit(alert.New(alert.SourceSystem, alert.SeverityWarning,
fmt.Sprintf("Slow heartbeat: %s (expected <%s)",
result.Duration, o.config.HeartbeatInterval), cycle))
}
// --- Peer discovery ---
if result.PeersDiscovered > 0 {
o.alertBus.Emit(alert.New(alert.SourcePeer, alert.SeverityInfo,
fmt.Sprintf("Discovered %d new peer(s)", result.PeersDiscovered), cycle))
}
// --- Sync events ---
if result.FactsSynced > 0 {
o.alertBus.Emit(alert.New(alert.SourcePeer, alert.SeverityInfo,
fmt.Sprintf("Synced %d facts to peers", result.FactsSynced), cycle))
}
// --- Status heartbeat (every cycle) ---
if len(result.Errors) == 0 && result.GenomeIntact {
o.alertBus.Emit(alert.New(alert.SourceWatcher, alert.SeverityInfo,
fmt.Sprintf("Heartbeat OK (cycle=%d, entropy=%.4f)", cycle, result.EntropyLevel), cycle))
}
}
// autoDiscover checks configured peers and initiates handshakes.
func (o *Orchestrator) autoDiscover(ctx context.Context) int {
localHash := memory.CompiledGenomeHash()
discovered := 0
for _, peerSpec := range o.config.KnownPeers {
// Parse "node_name:genome_hash" format.
nodeName, hash := parsePeerSpec(peerSpec)
if hash == "" {
continue
}
// Skip if already trusted.
// Use hash as pseudo peer_id for discovery.
peerID := "discovered_" + hash[:12]
if o.peerReg.IsTrusted(peerID) {
o.peerReg.TouchPeer(peerID)
continue
}
req := peer.HandshakeRequest{
FromPeerID: peerID,
FromNode: nodeName,
GenomeHash: hash,
Timestamp: time.Now().Unix(),
}
resp, err := o.peerReg.ProcessHandshake(req, localHash)
if err != nil {
continue
}
if resp.Match {
discovered++
log.Printf("orchestrator: discovered trusted peer %s [%s]", nodeName, peerID)
}
}
// Check for timed-out peers.
genes, _ := o.store.ListGenes(ctx)
syncFacts := genesToSyncFacts(genes)
backups := o.peerReg.CheckTimeouts(syncFacts)
if len(backups) > 0 {
log.Printf("orchestrator: %d peers timed out, gene backups created", len(backups))
}
return discovered
}
// stabilityCheck verifies genome integrity and measures entropy.
func (o *Orchestrator) stabilityCheck(ctx context.Context, result *HeartbeatResult) (bool, float64) {
// Check genome integrity via gene count.
genes, err := o.store.ListGenes(ctx)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("list genes: %v", err))
return false, 0
}
genomeOK := len(genes) >= len(memory.HardcodedGenes)
// Compute entropy on USER-CREATED facts only.
// System facts (genes, watchdog, heartbeat, session-history) are excluded —
// their entropy is irrelevant for anomaly detection.
l0Facts, _ := o.store.ListByLevel(ctx, memory.LevelProject)
l1Facts, _ := o.store.ListByLevel(ctx, memory.LevelDomain)
var dynamicContent string
for _, f := range append(l0Facts, l1Facts...) {
if f.IsGene {
continue
}
// Only include user-created content — source "manual" (add_fact) or "mcp".
if f.Source != "manual" && f.Source != "mcp" {
continue
}
dynamicContent += f.Content + " "
}
// No dynamic facts = healthy (entropy 0).
if dynamicContent == "" {
return genomeOK, 0
}
entropyLevel := entropy.ShannonEntropy(dynamicContent)
// Normalize entropy to 0-1 range (typical text: 3-5 bits/char).
normalizedEntropy := entropyLevel / 5.0
if normalizedEntropy > 1.0 {
normalizedEntropy = 1.0
}
if normalizedEntropy >= o.config.EntropyThreshold {
result.ApoptosisTriggered = true
currentHash := memory.CompiledGenomeHash()
recoveryMarker := memory.NewFact(
fmt.Sprintf("[WATCHDOG_RECOVERY] genome_hash=%s entropy=%.4f cycle=%d",
currentHash, normalizedEntropy, result.Cycle),
memory.LevelProject,
"recovery",
"watchdog",
)
recoveryMarker.Source = "watchdog"
_ = o.store.Add(ctx, recoveryMarker)
}
return genomeOK, normalizedEntropy
}
// syncManager exports facts to all trusted peers.
func (o *Orchestrator) syncManager(ctx context.Context, result *HeartbeatResult) int {
// Check if we have new facts since last sync.
l0Facts, err := o.store.ListByLevel(ctx, memory.LevelProject)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("list L0: %v", err))
return 0
}
l1Facts, err := o.store.ListByLevel(ctx, memory.LevelDomain)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("list L1: %v", err))
return 0
}
totalFacts := len(l0Facts) + len(l1Facts)
o.mu.RLock()
lastCount := o.lastFactCount
o.mu.RUnlock()
// Skip sync if no changes and sync_on_change is enabled.
if o.config.SyncOnChange && totalFacts == lastCount && !o.lastSync.IsZero() {
return 0
}
// Build sync payload.
allFacts := append(l0Facts, l1Facts...)
syncFacts := make([]peer.SyncFact, 0, len(allFacts))
for _, f := range allFacts {
if f.IsStale || f.IsArchived {
continue
}
syncFacts = append(syncFacts, peer.SyncFact{
ID: f.ID,
Content: f.Content,
Level: int(f.Level),
Domain: f.Domain,
Module: f.Module,
IsGene: f.IsGene,
Source: f.Source,
CreatedAt: f.CreatedAt,
})
}
if len(syncFacts) > o.config.MaxSyncBatchSize {
syncFacts = syncFacts[:o.config.MaxSyncBatchSize]
}
// Record sync readiness for all trusted peers.
trustedPeers := o.peerReg.ListPeers()
synced := 0
for _, p := range trustedPeers {
if p.Trust == peer.TrustVerified {
_ = o.peerReg.RecordSync(p.PeerID, len(syncFacts))
synced += len(syncFacts)
}
}
o.mu.Lock()
o.lastSync = time.Now()
o.lastFactCount = totalFacts
o.mu.Unlock()
return synced
}
// jitteredInterval returns the next heartbeat interval with random jitter.
func (o *Orchestrator) jitteredInterval() time.Duration {
base := o.config.HeartbeatInterval
jitterRange := time.Duration(float64(base) * float64(o.config.JitterPercent) / 100.0)
jitter := time.Duration(rand.Int63n(int64(jitterRange)*2)) - jitterRange
interval := base + jitter
if interval < 10*time.Millisecond {
interval = 10 * time.Millisecond
}
return interval
}
// IsRunning returns whether the orchestrator is active.
func (o *Orchestrator) IsRunning() bool {
o.mu.RLock()
defer o.mu.RUnlock()
return o.running
}
// Stats returns current orchestrator status.
func (o *Orchestrator) Stats() map[string]interface{} {
o.mu.RLock()
defer o.mu.RUnlock()
stats := map[string]interface{}{
"running": o.running,
"total_cycles": o.cycle,
"config": o.config,
"last_sync": o.lastSync,
"last_fact_count": o.lastFactCount,
"history_size": len(o.history),
}
if len(o.history) > 0 {
last := o.history[len(o.history)-1]
stats["last_heartbeat"] = last
}
return stats
}
// History returns recent heartbeat results.
func (o *Orchestrator) History() []HeartbeatResult {
o.mu.RLock()
defer o.mu.RUnlock()
result := make([]HeartbeatResult, len(o.history))
copy(result, o.history)
return result
}
// selfHeal checks for missing hardcoded genes and re-bootstraps them.
// Returns the number of genes restored.
func (o *Orchestrator) selfHeal(ctx context.Context, result *HeartbeatResult) int {
genes, err := o.store.ListGenes(ctx)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("self-heal list genes: %v", err))
return 0
}
// Check if all hardcoded genes are present.
if len(genes) >= len(memory.HardcodedGenes) {
return 0 // All present, nothing to heal.
}
// Some genes missing — re-bootstrap.
healed, err := memory.BootstrapGenome(ctx, o.store, "")
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("self-heal bootstrap: %v", err))
return 0
}
if healed > 0 {
log.Printf("orchestrator: self-healed %d missing genes", healed)
}
return healed
}
// memoryHygiene processes expired TTL facts and archives stale ones.
// Returns (expired_count, archived_count).
func (o *Orchestrator) memoryHygiene(ctx context.Context, result *HeartbeatResult) (int, int) {
expired := 0
archived := 0
// Step 1: Mark expired TTL facts as stale.
expiredFacts, err := o.store.GetExpired(ctx)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("hygiene get-expired: %v", err))
return 0, 0
}
for _, f := range expiredFacts {
if f.IsGene {
continue // Never expire genes.
}
f.IsStale = true
if err := o.store.Update(ctx, f); err == nil {
expired++
}
}
// Step 2: Archive facts that have been stale for a while.
staleFacts, err := o.store.GetStale(ctx, false) // exclude already-archived
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("hygiene get-stale: %v", err))
return expired, 0
}
staleThreshold := time.Now().Add(-24 * time.Hour) // Archive if stale > 24h.
for _, f := range staleFacts {
if f.IsGene {
continue // Never archive genes.
}
if f.UpdatedAt.Before(staleThreshold) {
f.IsArchived = true
if err := o.store.Update(ctx, f); err == nil {
archived++
}
}
}
if expired > 0 || archived > 0 {
log.Printf("orchestrator: hygiene — expired %d facts, archived %d stale facts", expired, archived)
}
return expired, archived
}
// statePersistence writes a heartbeat snapshot every N cycles.
// This creates a persistent breadcrumb trail that survives restarts.
func (o *Orchestrator) statePersistence(ctx context.Context, result *HeartbeatResult) {
// Snapshot every 50 cycles (avoids memory inflation in fast-heartbeat TUI mode).
if result.Cycle%50 != 0 {
return
}
snapshot := memory.NewFact(
fmt.Sprintf("[HEARTBEAT_SNAPSHOT] cycle=%d genome=%v entropy=%.4f peers=%d synced=%d healed=%d",
result.Cycle, result.GenomeIntact, result.EntropyLevel,
result.PeersDiscovered, result.FactsSynced, result.GenesHealed),
memory.LevelProject,
"orchestrator",
"heartbeat",
)
snapshot.Source = "heartbeat"
if err := o.store.Add(ctx, snapshot); err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("snapshot: %v", err))
}
}
// --- Helpers ---
func parsePeerSpec(spec string) (nodeName, hash string) {
for i, c := range spec {
if c == ':' {
return spec[:i], spec[i+1:]
}
}
return "unknown", spec
}
func genesToSyncFacts(genes []*memory.Fact) []peer.SyncFact {
facts := make([]peer.SyncFact, 0, len(genes))
for _, g := range genes {
facts = append(facts, peer.SyncFact{
ID: g.ID,
Content: g.Content,
Level: int(g.Level),
Domain: g.Domain,
IsGene: g.IsGene,
Source: g.Source,
})
}
return facts
}
// SetSynapseStore enables Module 9 (Synapse Scanner) at runtime.
func (o *Orchestrator) SetSynapseStore(store synapse.SynapseStore) {
o.mu.Lock()
defer o.mu.Unlock()
o.synapseStore = store
}
// synapseScanner is Module 9: automatic semantic link discovery.
// Scans active facts and proposes PENDING synapse connections based on
// domain overlap and keyword similarity. Threshold: 0.85.
func (o *Orchestrator) synapseScanner(ctx context.Context, result *HeartbeatResult) int {
// Get all non-stale, non-archived facts.
allFacts := make([]*memory.Fact, 0)
for level := 0; level <= 3; level++ {
hl, ok := memory.HierLevelFromInt(level)
if !ok {
continue
}
facts, err := o.store.ListByLevel(ctx, hl)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("synapse_scan L%d: %v", level, err))
continue
}
for _, f := range facts {
if !f.IsGene && !f.IsStale && !f.IsArchived {
allFacts = append(allFacts, f)
}
}
}
if len(allFacts) < 2 {
return 0
}
created := 0
// Compare pairs: O(n²) but fact count is small (typically <500).
for i := 0; i < len(allFacts)-1 && i < 200; i++ {
for j := i + 1; j < len(allFacts) && j < 200; j++ {
a, b := allFacts[i], allFacts[j]
confidence := synapseSimilarity(a, b)
if confidence < 0.85 {
continue
}
// Check if synapse already exists.
exists, err := o.synapseStore.Exists(ctx, a.ID, b.ID)
if err != nil || exists {
continue
}
_, err = o.synapseStore.Create(ctx, a.ID, b.ID, confidence)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("synapse_create: %v", err))
continue
}
created++
}
}
if created > 0 && o.alertBus != nil {
o.alertBus.Emit(alert.New(
alert.SourceMemory,
alert.SeverityInfo,
fmt.Sprintf("Synapse Scanner: created %d new bridges", created),
result.Cycle,
))
}
return created
}
// synapseSimilarity computes a confidence score between two facts.
// Returns 0.01.0 based on domain match and keyword overlap.
func synapseSimilarity(a, b *memory.Fact) float64 {
score := 0.0
// Same domain → strong signal.
if a.Domain != "" && a.Domain == b.Domain {
score += 0.50
}
// Same module → additional signal.
if a.Module != "" && a.Module == b.Module {
score += 0.20
}
// Keyword overlap (words > 3 chars).
wordsA := tokenize(a.Content)
wordsB := tokenize(b.Content)
if len(wordsA) > 0 && len(wordsB) > 0 {
overlap := 0
for w := range wordsA {
if wordsB[w] {
overlap++
}
}
total := len(wordsA)
if len(wordsB) < total {
total = len(wordsB)
}
if total > 0 {
score += 0.30 * float64(overlap) / float64(total)
}
}
if score > 1.0 {
score = 1.0
}
return score
}
// tokenize splits text into unique lowercase words (>3 chars).
func tokenize(text string) map[string]bool {
words := make(map[string]bool)
current := make([]byte, 0, 32)
for i := 0; i < len(text); i++ {
c := text[i]
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' {
if c >= 'A' && c <= 'Z' {
c += 32 // toLower
}
current = append(current, c)
} else {
if len(current) > 3 {
words[string(current)] = true
}
current = current[:0]
}
}
if len(current) > 3 {
words[string(current)] = true
}
return words
}

View file

@ -0,0 +1,318 @@
package orchestrator
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/domain/peer"
)
func newTestOrchestrator(t *testing.T, cfg Config) (*Orchestrator, *inMemoryStore) {
t.Helper()
store := newInMemoryStore()
peerReg := peer.NewRegistry("test-node", 30*time.Minute)
// Bootstrap genes into store.
ctx := context.Background()
for _, gd := range memory.HardcodedGenes {
gene := memory.NewGene(gd.Content, gd.Domain)
gene.ID = gd.ID
_ = store.Add(ctx, gene)
}
return New(cfg, peerReg, store), store
}
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
assert.Equal(t, 5*time.Minute, cfg.HeartbeatInterval)
assert.Equal(t, 30, cfg.JitterPercent)
assert.Equal(t, 0.95, cfg.EntropyThreshold)
assert.True(t, cfg.SyncOnChange)
assert.Equal(t, 100, cfg.MaxSyncBatchSize)
}
func TestNew_WithDefaults(t *testing.T) {
o, _ := newTestOrchestrator(t, DefaultConfig())
assert.False(t, o.IsRunning())
assert.Equal(t, 0, o.cycle)
}
func TestHeartbeat_SingleCycle(t *testing.T) {
cfg := DefaultConfig()
cfg.HeartbeatInterval = 100 * time.Millisecond
cfg.EntropyThreshold = 1.1 // Above max normalized — won't trigger apoptosis.
o, _ := newTestOrchestrator(t, cfg)
result := o.heartbeat(context.Background())
assert.Equal(t, 1, result.Cycle)
assert.True(t, result.GenomeIntact, "Genome must be intact with all hardcoded genes")
assert.GreaterOrEqual(t, result.Duration, time.Duration(0))
assert.Greater(t, result.NextInterval, time.Duration(0))
}
func TestHeartbeat_GenomeIntact(t *testing.T) {
cfg := DefaultConfig()
cfg.EntropyThreshold = 1.1 // Above max normalized — won't trigger apoptosis.
o, _ := newTestOrchestrator(t, cfg)
result := o.heartbeat(context.Background())
assert.True(t, result.GenomeIntact)
assert.False(t, result.ApoptosisTriggered)
assert.Empty(t, result.Errors)
}
func TestAutoDiscover_ConfiguredPeers(t *testing.T) {
cfg := DefaultConfig()
cfg.KnownPeers = []string{
"node-alpha:" + memory.CompiledGenomeHash(), // matching hash
"node-evil:deadbeefdeadbeef", // non-matching
}
o, _ := newTestOrchestrator(t, cfg)
discovered := o.autoDiscover(context.Background())
assert.Equal(t, 1, discovered, "Only matching genome should be discovered")
// Second call: already trusted, should not re-discover.
discovered2 := o.autoDiscover(context.Background())
assert.Equal(t, 0, discovered2, "Already trusted peer should not be re-discovered")
}
func TestSyncManager_NoTrustedPeers(t *testing.T) {
o, _ := newTestOrchestrator(t, DefaultConfig())
result := HeartbeatResult{}
synced := o.syncManager(context.Background(), &result)
assert.Equal(t, 0, synced, "No trusted peers = no sync")
}
func TestSyncManager_WithTrustedPeer(t *testing.T) {
cfg := DefaultConfig()
cfg.KnownPeers = []string{"peer:" + memory.CompiledGenomeHash()}
o, _ := newTestOrchestrator(t, cfg)
// Discover peer first.
o.autoDiscover(context.Background())
result := HeartbeatResult{}
synced := o.syncManager(context.Background(), &result)
assert.Greater(t, synced, 0, "Trusted peer should receive synced facts")
}
func TestSyncManager_SkipWhenNoChanges(t *testing.T) {
cfg := DefaultConfig()
cfg.SyncOnChange = true
cfg.KnownPeers = []string{"peer:" + memory.CompiledGenomeHash()}
o, _ := newTestOrchestrator(t, cfg)
o.autoDiscover(context.Background())
result := HeartbeatResult{}
// First sync.
synced1 := o.syncManager(context.Background(), &result)
assert.Greater(t, synced1, 0)
// Second sync — no changes.
synced2 := o.syncManager(context.Background(), &result)
assert.Equal(t, 0, synced2, "No new facts = skip sync")
}
func TestJitteredInterval(t *testing.T) {
cfg := DefaultConfig()
cfg.HeartbeatInterval = 1 * time.Second
cfg.JitterPercent = 50
o, _ := newTestOrchestrator(t, cfg)
intervals := make(map[time.Duration]bool)
for i := 0; i < 20; i++ {
interval := o.jitteredInterval()
intervals[interval] = true
// Must be between 500ms and 1500ms.
assert.GreaterOrEqual(t, interval, 500*time.Millisecond)
assert.LessOrEqual(t, interval, 1500*time.Millisecond)
}
// With 20 samples and 50% jitter, we should get some variety.
assert.Greater(t, len(intervals), 1, "Jitter should produce varied intervals")
}
func TestStartAndStop(t *testing.T) {
cfg := DefaultConfig()
cfg.HeartbeatInterval = 50 * time.Millisecond
cfg.JitterPercent = 10
o, _ := newTestOrchestrator(t, cfg)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
assert.False(t, o.IsRunning())
go o.Start(ctx)
time.Sleep(50 * time.Millisecond)
assert.True(t, o.IsRunning())
<-ctx.Done()
time.Sleep(100 * time.Millisecond)
assert.False(t, o.IsRunning())
assert.GreaterOrEqual(t, o.cycle, 1, "At least one cycle should have completed")
}
func TestStats(t *testing.T) {
cfg := DefaultConfig()
cfg.HeartbeatInterval = 50 * time.Millisecond
cfg.JitterPercent = 10
o, _ := newTestOrchestrator(t, cfg)
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
go o.Start(ctx)
time.Sleep(100 * time.Millisecond)
stats := o.Stats()
assert.True(t, stats["running"].(bool) || stats["total_cycles"].(int) >= 1)
assert.GreaterOrEqual(t, stats["total_cycles"].(int), 1)
}
func TestHistory(t *testing.T) {
cfg := DefaultConfig()
cfg.HeartbeatInterval = 30 * time.Millisecond
cfg.JitterPercent = 10
o, _ := newTestOrchestrator(t, cfg)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go o.Start(ctx)
<-ctx.Done()
time.Sleep(50 * time.Millisecond)
history := o.History()
assert.GreaterOrEqual(t, len(history), 2, "Should have at least 2 cycles")
assert.Equal(t, 1, history[0].Cycle)
}
func TestParsePeerSpec(t *testing.T) {
tests := []struct {
spec string
wantNode string
wantHash string
}{
{"alpha:abc123", "alpha", "abc123"},
{"abc123", "unknown", "abc123"},
{"node-1:hash:with:colons", "node-1", "hash:with:colons"},
}
for _, tt := range tests {
node, hash := parsePeerSpec(tt.spec)
assert.Equal(t, tt.wantNode, node, "spec=%s", tt.spec)
assert.Equal(t, tt.wantHash, hash, "spec=%s", tt.spec)
}
}
// --- In-memory FactStore for testing ---
type inMemoryStore struct {
facts map[string]*memory.Fact
}
func newInMemoryStore() *inMemoryStore {
return &inMemoryStore{facts: make(map[string]*memory.Fact)}
}
func (s *inMemoryStore) Add(_ context.Context, fact *memory.Fact) error {
if _, exists := s.facts[fact.ID]; exists {
return fmt.Errorf("duplicate: %s", fact.ID)
}
f := *fact
s.facts[fact.ID] = &f
return nil
}
func (s *inMemoryStore) Get(_ context.Context, id string) (*memory.Fact, error) {
f, ok := s.facts[id]
if !ok {
return nil, fmt.Errorf("not found: %s", id)
}
return f, nil
}
func (s *inMemoryStore) Update(_ context.Context, fact *memory.Fact) error {
s.facts[fact.ID] = fact
return nil
}
func (s *inMemoryStore) Delete(_ context.Context, id string) error {
delete(s.facts, id)
return nil
}
func (s *inMemoryStore) ListByDomain(_ context.Context, domain string, _ bool) ([]*memory.Fact, error) {
var result []*memory.Fact
for _, f := range s.facts {
if f.Domain == domain {
result = append(result, f)
}
}
return result, nil
}
func (s *inMemoryStore) ListByLevel(_ context.Context, level memory.HierLevel) ([]*memory.Fact, error) {
var result []*memory.Fact
for _, f := range s.facts {
if f.Level == level {
result = append(result, f)
}
}
return result, nil
}
func (s *inMemoryStore) ListDomains(_ context.Context) ([]string, error) {
domains := make(map[string]bool)
for _, f := range s.facts {
domains[f.Domain] = true
}
result := make([]string, 0, len(domains))
for d := range domains {
result = append(result, d)
}
return result, nil
}
func (s *inMemoryStore) GetStale(_ context.Context, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (s *inMemoryStore) Search(_ context.Context, _ string, _ int) ([]*memory.Fact, error) {
return nil, nil
}
func (s *inMemoryStore) ListGenes(_ context.Context) ([]*memory.Fact, error) {
var result []*memory.Fact
for _, f := range s.facts {
if f.IsGene {
result = append(result, f)
}
}
return result, nil
}
func (s *inMemoryStore) GetExpired(_ context.Context) ([]*memory.Fact, error) {
return nil, nil
}
func (s *inMemoryStore) RefreshTTL(_ context.Context, _ string) error {
return nil
}
func (s *inMemoryStore) TouchFact(_ context.Context, _ string) error { return nil }
func (s *inMemoryStore) GetColdFacts(_ context.Context, _ int) ([]*memory.Fact, error) {
return nil, nil
}
func (s *inMemoryStore) CompressFacts(_ context.Context, _ []string, _ string) (string, error) {
return "", nil
}
func (s *inMemoryStore) Stats(_ context.Context) (*memory.FactStoreStats, error) {
return &memory.FactStoreStats{TotalFacts: len(s.facts)}, nil
}

View file

@ -0,0 +1,64 @@
// Package resources provides MCP resource implementations.
package resources
import (
"context"
"encoding/json"
"fmt"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/domain/session"
)
// Provider serves MCP resources (rlm://state, rlm://facts, rlm://stats).
type Provider struct {
factStore memory.FactStore
stateStore session.StateStore
}
// NewProvider creates a new resource Provider.
func NewProvider(factStore memory.FactStore, stateStore session.StateStore) *Provider {
return &Provider{
factStore: factStore,
stateStore: stateStore,
}
}
// GetState returns the current cognitive state for a session as JSON.
func (p *Provider) GetState(ctx context.Context, sessionID string) (string, error) {
state, _, err := p.stateStore.Load(ctx, sessionID, nil)
if err != nil {
return "", fmt.Errorf("load state: %w", err)
}
data, err := json.MarshalIndent(state, "", " ")
if err != nil {
return "", fmt.Errorf("marshal state: %w", err)
}
return string(data), nil
}
// GetFacts returns L0 facts as JSON.
func (p *Provider) GetFacts(ctx context.Context) (string, error) {
facts, err := p.factStore.ListByLevel(ctx, memory.LevelProject)
if err != nil {
return "", fmt.Errorf("list L0 facts: %w", err)
}
data, err := json.MarshalIndent(facts, "", " ")
if err != nil {
return "", fmt.Errorf("marshal facts: %w", err)
}
return string(data), nil
}
// GetStats returns fact store statistics as JSON.
func (p *Provider) GetStats(ctx context.Context) (string, error) {
stats, err := p.factStore.Stats(ctx)
if err != nil {
return "", fmt.Errorf("get stats: %w", err)
}
data, err := json.MarshalIndent(stats, "", " ")
if err != nil {
return "", fmt.Errorf("marshal stats: %w", err)
}
return string(data), nil
}

View file

@ -0,0 +1,150 @@
package resources
import (
"context"
"encoding/json"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/domain/session"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestProvider(t *testing.T) (*Provider, *sqlite.DB, *sqlite.DB) {
t.Helper()
factDB, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { factDB.Close() })
stateDB, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { stateDB.Close() })
factRepo, err := sqlite.NewFactRepo(factDB)
require.NoError(t, err)
stateRepo, err := sqlite.NewStateRepo(stateDB)
require.NoError(t, err)
return NewProvider(factRepo, stateRepo), factDB, stateDB
}
func TestNewProvider(t *testing.T) {
p, _, _ := newTestProvider(t)
require.NotNil(t, p)
assert.NotNil(t, p.factStore)
assert.NotNil(t, p.stateStore)
}
func TestProvider_GetFacts_Empty(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
result, err := p.GetFacts(ctx)
require.NoError(t, err)
var facts []interface{}
require.NoError(t, json.Unmarshal([]byte(result), &facts))
assert.Empty(t, facts)
}
func TestProvider_GetFacts_WithData(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
// Add L0 facts directly via factStore.
f1 := memory.NewFact("Project uses Go", memory.LevelProject, "core", "")
f2 := memory.NewFact("Domain fact", memory.LevelDomain, "backend", "")
require.NoError(t, p.factStore.Add(ctx, f1))
require.NoError(t, p.factStore.Add(ctx, f2))
result, err := p.GetFacts(ctx)
require.NoError(t, err)
// Should only return L0 facts.
assert.Contains(t, result, "Project uses Go")
assert.NotContains(t, result, "Domain fact")
}
func TestProvider_GetStats(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
// Add some facts.
f1 := memory.NewFact("fact1", memory.LevelProject, "core", "")
f2 := memory.NewFact("fact2", memory.LevelDomain, "core", "")
require.NoError(t, p.factStore.Add(ctx, f1))
require.NoError(t, p.factStore.Add(ctx, f2))
result, err := p.GetStats(ctx)
require.NoError(t, err)
assert.Contains(t, result, "total_facts")
var stats map[string]interface{}
require.NoError(t, json.Unmarshal([]byte(result), &stats))
assert.Equal(t, float64(2), stats["total_facts"])
}
func TestProvider_GetStats_Empty(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
result, err := p.GetStats(ctx)
require.NoError(t, err)
var stats map[string]interface{}
require.NoError(t, json.Unmarshal([]byte(result), &stats))
assert.Equal(t, float64(0), stats["total_facts"])
}
func TestProvider_GetState(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
// Save a state first.
state := session.NewCognitiveStateVector("test-session")
state.SetGoal("Build GoMCP", 0.5)
state.AddFact("Go 1.25", "requirement", 1.0)
checksum := state.Checksum()
require.NoError(t, p.stateStore.Save(ctx, state, checksum))
result, err := p.GetState(ctx, "test-session")
require.NoError(t, err)
assert.Contains(t, result, "test-session")
assert.Contains(t, result, "Build GoMCP")
}
func TestProvider_GetState_NotFound(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
_, err := p.GetState(ctx, "nonexistent")
assert.Error(t, err)
}
func TestProvider_GetFacts_JSONFormat(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
f := memory.NewFact("JSON test", memory.LevelProject, "test", "")
require.NoError(t, p.factStore.Add(ctx, f))
result, err := p.GetFacts(ctx)
require.NoError(t, err)
// Should be valid indented JSON.
assert.True(t, json.Valid([]byte(result)))
assert.Contains(t, result, "\n") // Indented.
}
func TestProvider_GetStats_JSONFormat(t *testing.T) {
p, _, _ := newTestProvider(t)
ctx := context.Background()
result, err := p.GetStats(ctx)
require.NoError(t, err)
assert.True(t, json.Valid([]byte(result)))
}

View file

@ -0,0 +1,253 @@
// Package soc provides SOC analytics: event trends, severity distribution,
// top sources, MITRE ATT&CK coverage, and time-series aggregation.
package soc
import (
"sort"
"time"
domsoc "github.com/sentinel-community/gomcp/internal/domain/soc"
)
// ─── Analytics Types ──────────────────────────────────────
// TimeSeriesPoint represents a single data point in time series.
type TimeSeriesPoint struct {
Timestamp time.Time `json:"timestamp"`
Count int `json:"count"`
}
// SeverityDistribution counts events by severity.
type SeverityDistribution struct {
Critical int `json:"critical"`
High int `json:"high"`
Medium int `json:"medium"`
Low int `json:"low"`
Info int `json:"info"`
}
// SourceBreakdown counts events per source.
type SourceBreakdown struct {
Source string `json:"source"`
Count int `json:"count"`
}
// CategoryBreakdown counts events per category.
type CategoryBreakdown struct {
Category string `json:"category"`
Count int `json:"count"`
}
// IncidentTimeline shows incident trend.
type IncidentTimeline struct {
Created []TimeSeriesPoint `json:"created"`
Resolved []TimeSeriesPoint `json:"resolved"`
}
// AnalyticsReport is the full SOC analytics output.
type AnalyticsReport struct {
GeneratedAt time.Time `json:"generated_at"`
TimeRange struct {
From time.Time `json:"from"`
To time.Time `json:"to"`
} `json:"time_range"`
// Event analytics
EventTrend []TimeSeriesPoint `json:"event_trend"`
SeverityDistribution SeverityDistribution `json:"severity_distribution"`
TopSources []SourceBreakdown `json:"top_sources"`
TopCategories []CategoryBreakdown `json:"top_categories"`
// Incident analytics
IncidentTimeline IncidentTimeline `json:"incident_timeline"`
MTTR float64 `json:"mttr_hours"` // Mean Time to Resolve
// Derived KPIs
EventsPerHour float64 `json:"events_per_hour"`
IncidentRate float64 `json:"incident_rate"` // incidents / 100 events
}
// ─── Analytics Functions ──────────────────────────────────
// GenerateReport builds a full analytics report from events and incidents.
func GenerateReport(events []domsoc.SOCEvent, incidents []domsoc.Incident, windowHours int) *AnalyticsReport {
if windowHours <= 0 {
windowHours = 24
}
now := time.Now()
windowStart := now.Add(-time.Duration(windowHours) * time.Hour)
report := &AnalyticsReport{
GeneratedAt: now,
}
report.TimeRange.From = windowStart
report.TimeRange.To = now
// Filter events within window
var windowEvents []domsoc.SOCEvent
for _, e := range events {
if e.Timestamp.After(windowStart) {
windowEvents = append(windowEvents, e)
}
}
// Severity distribution
report.SeverityDistribution = calcSeverityDist(windowEvents)
// Event trend (hourly buckets)
report.EventTrend = calcEventTrend(windowEvents, windowStart, now)
// Top sources
report.TopSources = calcTopSources(windowEvents, 10)
// Top categories
report.TopCategories = calcTopCategories(windowEvents, 10)
// Incident timeline
report.IncidentTimeline = calcIncidentTimeline(incidents, windowStart, now)
// MTTR
report.MTTR = calcMTTR(incidents)
// KPIs
hours := now.Sub(windowStart).Hours()
if hours > 0 {
report.EventsPerHour = float64(len(windowEvents)) / hours
}
if len(windowEvents) > 0 {
report.IncidentRate = float64(len(incidents)) / float64(len(windowEvents)) * 100
}
return report
}
// ─── Internal Computations ────────────────────────────────
func calcSeverityDist(events []domsoc.SOCEvent) SeverityDistribution {
var d SeverityDistribution
for _, e := range events {
switch e.Severity {
case domsoc.SeverityCritical:
d.Critical++
case domsoc.SeverityHigh:
d.High++
case domsoc.SeverityMedium:
d.Medium++
case domsoc.SeverityLow:
d.Low++
case domsoc.SeverityInfo:
d.Info++
}
}
return d
}
func calcEventTrend(events []domsoc.SOCEvent, from, to time.Time) []TimeSeriesPoint {
hours := int(to.Sub(from).Hours()) + 1
buckets := make([]int, hours)
for _, e := range events {
idx := int(e.Timestamp.Sub(from).Hours())
if idx >= 0 && idx < len(buckets) {
buckets[idx]++
}
}
points := make([]TimeSeriesPoint, hours)
for i := range points {
points[i] = TimeSeriesPoint{
Timestamp: from.Add(time.Duration(i) * time.Hour),
Count: buckets[i],
}
}
return points
}
func calcTopSources(events []domsoc.SOCEvent, limit int) []SourceBreakdown {
counts := make(map[string]int)
for _, e := range events {
counts[string(e.Source)]++
}
result := make([]SourceBreakdown, 0, len(counts))
for src, cnt := range counts {
result = append(result, SourceBreakdown{Source: src, Count: cnt})
}
sort.Slice(result, func(i, j int) bool {
return result[i].Count > result[j].Count
})
if len(result) > limit {
result = result[:limit]
}
return result
}
func calcTopCategories(events []domsoc.SOCEvent, limit int) []CategoryBreakdown {
counts := make(map[string]int)
for _, e := range events {
counts[string(e.Category)]++
}
result := make([]CategoryBreakdown, 0, len(counts))
for cat, cnt := range counts {
result = append(result, CategoryBreakdown{Category: cat, Count: cnt})
}
sort.Slice(result, func(i, j int) bool {
return result[i].Count > result[j].Count
})
if len(result) > limit {
result = result[:limit]
}
return result
}
func calcIncidentTimeline(incidents []domsoc.Incident, from, to time.Time) IncidentTimeline {
hours := int(to.Sub(from).Hours()) + 1
created := make([]int, hours)
resolved := make([]int, hours)
for _, inc := range incidents {
idx := int(inc.CreatedAt.Sub(from).Hours())
if idx >= 0 && idx < hours {
created[idx]++
}
if inc.Status == domsoc.StatusResolved {
ridx := int(inc.UpdatedAt.Sub(from).Hours())
if ridx >= 0 && ridx < hours {
resolved[ridx]++
}
}
}
timeline := IncidentTimeline{
Created: make([]TimeSeriesPoint, hours),
Resolved: make([]TimeSeriesPoint, hours),
}
for i := range timeline.Created {
t := from.Add(time.Duration(i) * time.Hour)
timeline.Created[i] = TimeSeriesPoint{Timestamp: t, Count: created[i]}
timeline.Resolved[i] = TimeSeriesPoint{Timestamp: t, Count: resolved[i]}
}
return timeline
}
func calcMTTR(incidents []domsoc.Incident) float64 {
var total float64
var count int
for _, inc := range incidents {
if inc.Status == domsoc.StatusResolved && !inc.UpdatedAt.IsZero() {
duration := inc.UpdatedAt.Sub(inc.CreatedAt).Hours()
if duration > 0 {
total += duration
count++
}
}
}
if count == 0 {
return 0
}
return total / float64(count)
}

View file

@ -0,0 +1,116 @@
package soc
import (
"testing"
"time"
domsoc "github.com/sentinel-community/gomcp/internal/domain/soc"
)
func TestGenerateReport_EmptyEvents(t *testing.T) {
report := GenerateReport(nil, nil, 24)
if report == nil {
t.Fatal("expected non-nil report")
}
if report.EventsPerHour != 0 {
t.Errorf("expected 0 events/hour, got %.2f", report.EventsPerHour)
}
if report.MTTR != 0 {
t.Errorf("expected 0 MTTR, got %.2f", report.MTTR)
}
}
func TestGenerateReport_SeverityDistribution(t *testing.T) {
now := time.Now()
events := []domsoc.SOCEvent{
{Severity: domsoc.SeverityCritical, Timestamp: now},
{Severity: domsoc.SeverityCritical, Timestamp: now},
{Severity: domsoc.SeverityHigh, Timestamp: now},
{Severity: domsoc.SeverityMedium, Timestamp: now},
{Severity: domsoc.SeverityLow, Timestamp: now},
{Severity: domsoc.SeverityInfo, Timestamp: now},
{Severity: domsoc.SeverityInfo, Timestamp: now},
{Severity: domsoc.SeverityInfo, Timestamp: now},
}
report := GenerateReport(events, nil, 1)
if report.SeverityDistribution.Critical != 2 {
t.Errorf("expected 2 critical, got %d", report.SeverityDistribution.Critical)
}
if report.SeverityDistribution.High != 1 {
t.Errorf("expected 1 high, got %d", report.SeverityDistribution.High)
}
if report.SeverityDistribution.Info != 3 {
t.Errorf("expected 3 info, got %d", report.SeverityDistribution.Info)
}
}
func TestGenerateReport_TopSources(t *testing.T) {
now := time.Now()
events := []domsoc.SOCEvent{
{Source: domsoc.SourceSentinelCore, Timestamp: now},
{Source: domsoc.SourceSentinelCore, Timestamp: now},
{Source: domsoc.SourceSentinelCore, Timestamp: now},
{Source: domsoc.SourceShield, Timestamp: now},
{Source: domsoc.SourceShield, Timestamp: now},
{Source: domsoc.SourceExternal, Timestamp: now},
}
report := GenerateReport(events, nil, 1)
if len(report.TopSources) == 0 {
t.Fatal("expected non-empty top sources")
}
// First source should be sentinel-core (3 events)
if report.TopSources[0].Source != string(domsoc.SourceSentinelCore) {
t.Errorf("expected top source sentinel-core, got %s", report.TopSources[0].Source)
}
if report.TopSources[0].Count != 3 {
t.Errorf("expected top source count 3, got %d", report.TopSources[0].Count)
}
}
func TestGenerateReport_MTTR(t *testing.T) {
now := time.Now()
incidents := []domsoc.Incident{
{
Status: domsoc.StatusResolved,
CreatedAt: now.Add(-3 * time.Hour),
UpdatedAt: now.Add(-1 * time.Hour),
},
{
Status: domsoc.StatusResolved,
CreatedAt: now.Add(-5 * time.Hour),
UpdatedAt: now.Add(-4 * time.Hour),
},
}
report := GenerateReport(nil, incidents, 24)
// MTTR = (2h + 1h) / 2 = 1.5h
if report.MTTR < 1.4 || report.MTTR > 1.6 {
t.Errorf("expected MTTR ~1.5h, got %.2f", report.MTTR)
}
}
func TestGenerateReport_IncidentRate(t *testing.T) {
now := time.Now()
events := make([]domsoc.SOCEvent, 100)
for i := range events {
events[i] = domsoc.SOCEvent{Timestamp: now, Severity: domsoc.SeverityLow}
}
incidents := make([]domsoc.Incident, 5)
for i := range incidents {
incidents[i] = domsoc.Incident{CreatedAt: now, Status: domsoc.StatusOpen}
}
report := GenerateReport(events, incidents, 1)
// 5 incidents / 100 events * 100 = 5%
if report.IncidentRate < 4.9 || report.IncidentRate > 5.1 {
t.Errorf("expected incident rate ~5%%, got %.2f%%", report.IncidentRate)
}
}

View file

@ -0,0 +1,661 @@
// Package soc provides application services for the SENTINEL AI SOC subsystem.
package soc
import (
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/sentinel-community/gomcp/internal/domain/oracle"
"github.com/sentinel-community/gomcp/internal/domain/peer"
domsoc "github.com/sentinel-community/gomcp/internal/domain/soc"
"github.com/sentinel-community/gomcp/internal/infrastructure/audit"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
)
const (
// MaxEventsPerSecondPerSensor limits event ingest rate per sensor (§17.3).
MaxEventsPerSecondPerSensor = 100
)
// Service orchestrates the SOC event pipeline:
// Step 0: Secret Scanner (INVARIANT) → DIP → Decision Logger → Persist → Correlation.
type Service struct {
mu sync.RWMutex
repo *sqlite.SOCRepo
logger *audit.DecisionLogger
rules []domsoc.SOCCorrelationRule
playbooks []domsoc.Playbook
sensors map[string]*domsoc.Sensor
// Rate limiting per sensor (§17.3): sensorID → timestamps of recent events.
sensorRates map[string][]time.Time
// Sensor authentication (§17.3 T-01): sensorID → pre-shared key.
sensorKeys map[string]string
// SOAR webhook notifier (§P3): outbound HTTP POST on incidents.
webhook *WebhookNotifier
// Threat intelligence store (§P3+): IOC enrichment.
threatIntel *ThreatIntelStore
}
// NewService creates a SOC service with persistence and decision logging.
func NewService(repo *sqlite.SOCRepo, logger *audit.DecisionLogger) *Service {
return &Service{
repo: repo,
logger: logger,
rules: domsoc.DefaultSOCCorrelationRules(),
playbooks: domsoc.DefaultPlaybooks(),
sensors: make(map[string]*domsoc.Sensor),
sensorRates: make(map[string][]time.Time),
}
}
// SetSensorKeys configures pre-shared keys for sensor authentication (§17.3 T-01).
// If keys is nil or empty, authentication is disabled (all events accepted).
func (s *Service) SetSensorKeys(keys map[string]string) {
s.mu.Lock()
defer s.mu.Unlock()
s.sensorKeys = keys
}
// SetWebhookConfig configures SOAR webhook notifications.
// If config has no endpoints, webhooks are disabled.
func (s *Service) SetWebhookConfig(config WebhookConfig) {
s.mu.Lock()
defer s.mu.Unlock()
s.webhook = NewWebhookNotifier(config)
}
// SetThreatIntel configures the threat intelligence store for IOC enrichment.
func (s *Service) SetThreatIntel(store *ThreatIntelStore) {
s.mu.Lock()
defer s.mu.Unlock()
s.threatIntel = store
}
// IngestEvent processes an incoming security event through the SOC pipeline.
// Returns the event ID and any incident created by correlation.
//
// Pipeline (§5.2):
//
// Step -1: Sensor Authentication — pre-shared key validation (§17.3 T-01)
// Step 0: Secret Scanner — INVARIANT, cannot be disabled (§5.4)
// Step 0.5: Rate Limiting — per sensor ≤100 events/sec (§17.3)
// Step 1: Decision Logger — SHA-256 chain with Zero-G tagging (§5.6, §13.4)
// Step 2: Persist event to SQLite
// Step 3: Update sensor registry (§11.3)
// Step 4: Run correlation engine (§7)
// Step 5: Apply playbooks (§10)
func (s *Service) IngestEvent(event domsoc.SOCEvent) (string, *domsoc.Incident, error) {
// Step -1: Sensor Authentication (§17.3 T-01)
// If sensorKeys configured, validate sensor_key before processing.
if len(s.sensorKeys) > 0 && event.SensorID != "" {
expected, exists := s.sensorKeys[event.SensorID]
if !exists || expected != event.SensorKey {
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
"AUTH_FAILED:REJECT",
fmt.Sprintf("sensor_id=%s reason=invalid_key", event.SensorID))
}
return "", nil, fmt.Errorf("soc: sensor auth failed for %s", event.SensorID)
}
}
// Step 0: Secret Scanner — INVARIANT (§5.4)
// always_active: true, cannot_disable: true
if event.Payload != "" {
scanResult := oracle.ScanForSecrets(event.Payload)
if scanResult.HasSecrets {
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
"SECRET_DETECTED:REJECT",
fmt.Sprintf("source=%s event_id=%s detections=%s",
event.Source, event.ID, strings.Join(scanResult.Detections, "; ")))
}
return "", nil, fmt.Errorf("soc: secret scanner rejected event: %d detections found", len(scanResult.Detections))
}
}
// Step 0.5: Rate Limiting per sensor (§17.3 T-02 DoS Protection)
sensorID := event.SensorID
if sensorID == "" {
sensorID = string(event.Source)
}
if s.isRateLimited(sensorID) {
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
"RATE_LIMIT_EXCEEDED:REJECT",
fmt.Sprintf("sensor=%s limit=%d/sec", sensorID, MaxEventsPerSecondPerSensor))
}
return "", nil, fmt.Errorf("soc: rate limit exceeded for sensor %s (max %d events/sec)", sensorID, MaxEventsPerSecondPerSensor)
}
// Step 1: Log decision with Zero-G tagging (§13.4)
if s.logger != nil {
zeroGTag := ""
if event.ZeroGMode {
zeroGTag = " zero_g_mode=true"
}
s.logger.Record(audit.ModuleSOC,
fmt.Sprintf("INGEST:%s", event.Verdict),
fmt.Sprintf("source=%s category=%s severity=%s confidence=%.2f%s",
event.Source, event.Category, event.Severity, event.Confidence, zeroGTag))
}
// Step 2: Persist event
if err := s.repo.InsertEvent(event); err != nil {
return "", nil, fmt.Errorf("soc: persist event: %w", err)
}
// Step 3: Update sensor registry (§11.3)
s.updateSensor(event)
// Step 3.5: Threat Intel IOC enrichment (§P3+)
if s.threatIntel != nil {
iocMatches := s.threatIntel.EnrichEvent(event.SensorID, event.Description)
if len(iocMatches) > 0 {
// Boost confidence and log IOC match
if event.Confidence < 0.9 {
event.Confidence = 0.9
}
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
fmt.Sprintf("IOC_MATCH:%d", len(iocMatches)),
fmt.Sprintf("event=%s ioc_type=%s ioc_value=%s source=%s",
event.ID, iocMatches[0].Type, iocMatches[0].Value, iocMatches[0].Source))
}
}
}
// Step 4: Run correlation against recent events (§7)
// Zero-G events are excluded from auto-response but still correlated.
incident := s.correlate(event)
// Step 5: Apply playbooks if incident created (§10)
// Skip auto-response for Zero-G events (§13.4: require_manual_approval: true)
if incident != nil && !event.ZeroGMode {
s.applyPlaybooks(event, incident)
} else if incident != nil && event.ZeroGMode {
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
"PLAYBOOK_SKIPPED:ZERO_G",
fmt.Sprintf("incident=%s reason=zero_g_mode_requires_manual_approval", incident.ID))
}
}
// Step 6: SOAR webhook notification (§P3)
if incident != nil && s.webhook != nil {
go s.webhook.NotifyIncident("incident_created", incident)
}
return event.ID, incident, nil
}
// isRateLimited checks if sensor exceeds MaxEventsPerSecondPerSensor (§17.3).
func (s *Service) isRateLimited(sensorID string) bool {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
cutoff := now.Add(-time.Second)
// Prune old timestamps.
timestamps := s.sensorRates[sensorID]
pruned := timestamps[:0]
for _, ts := range timestamps {
if ts.After(cutoff) {
pruned = append(pruned, ts)
}
}
pruned = append(pruned, now)
s.sensorRates[sensorID] = pruned
return len(pruned) > MaxEventsPerSecondPerSensor
}
// updateSensor registers/updates sentinel sensor on event ingest (§11.3 auto-discovery).
func (s *Service) updateSensor(event domsoc.SOCEvent) {
s.mu.Lock()
defer s.mu.Unlock()
sensorID := event.SensorID
if sensorID == "" {
sensorID = string(event.Source)
}
sensor, exists := s.sensors[sensorID]
if !exists {
newSensor := domsoc.NewSensor(sensorID, domsoc.SensorType(event.Source))
sensor = &newSensor
s.sensors[sensorID] = sensor
}
sensor.RecordEvent()
s.repo.UpsertSensor(*sensor)
}
// correlate runs correlation rules against recent events (§7).
func (s *Service) correlate(event domsoc.SOCEvent) *domsoc.Incident {
events, err := s.repo.ListEvents(100)
if err != nil || len(events) < 2 {
return nil
}
matches := domsoc.CorrelateSOCEvents(events, s.rules)
if len(matches) == 0 {
return nil
}
match := matches[0]
incident := domsoc.NewIncident(match.Rule.Name, match.Rule.Severity, match.Rule.ID)
incident.KillChainPhase = match.Rule.KillChainPhase
incident.MITREMapping = match.Rule.MITREMapping
for _, e := range match.Events {
incident.AddEvent(e.ID, e.Severity)
}
// Set decision chain anchor (§5.6)
if s.logger != nil {
anchor := s.logger.PrevHash()
incident.SetAnchor(anchor, s.logger.Count())
s.logger.Record(audit.ModuleCorrelation,
fmt.Sprintf("INCIDENT_CREATED:%s", incident.ID),
fmt.Sprintf("rule=%s severity=%s anchor=%s chain_length=%d",
match.Rule.ID, match.Rule.Severity, anchor, s.logger.Count()))
}
s.repo.InsertIncident(incident)
return &incident
}
// applyPlaybooks matches playbooks against the event and incident (§10).
func (s *Service) applyPlaybooks(event domsoc.SOCEvent, incident *domsoc.Incident) {
for _, pb := range s.playbooks {
if pb.Matches(event) {
incident.PlaybookApplied = pb.ID
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
fmt.Sprintf("PLAYBOOK_APPLIED:%s", pb.ID),
fmt.Sprintf("incident=%s actions=%v", incident.ID, pb.Actions))
}
break
}
}
}
// RecordHeartbeat processes a sensor heartbeat (§11.3).
func (s *Service) RecordHeartbeat(sensorID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
sensor, exists := s.sensors[sensorID]
if !exists {
return false, fmt.Errorf("sensor not found: %s", sensorID)
}
sensor.RecordHeartbeat()
if err := s.repo.UpsertSensor(*sensor); err != nil {
return false, fmt.Errorf("soc: upsert sensor: %w", err)
}
return true, nil
}
// CheckSensors runs heartbeat check on all sensors (§11.3).
// Returns sensors that transitioned to OFFLINE (need SOC alert).
func (s *Service) CheckSensors() []domsoc.Sensor {
s.mu.Lock()
defer s.mu.Unlock()
var offlineSensors []domsoc.Sensor
for _, sensor := range s.sensors {
if sensor.TimeSinceLastSeen() > time.Duration(domsoc.HeartbeatIntervalSec)*time.Second {
alertNeeded := sensor.MissHeartbeat()
s.repo.UpsertSensor(*sensor)
if alertNeeded {
offlineSensors = append(offlineSensors, *sensor)
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
"SENSOR_OFFLINE:ALERT",
fmt.Sprintf("sensor=%s type=%s missed=%d", sensor.SensorID, sensor.SensorType, sensor.MissedHeartbeats))
}
}
}
}
return offlineSensors
}
// ListEvents returns recent events with optional limit.
func (s *Service) ListEvents(limit int) ([]domsoc.SOCEvent, error) {
return s.repo.ListEvents(limit)
}
// ListIncidents returns incidents, optionally filtered by status.
func (s *Service) ListIncidents(status string, limit int) ([]domsoc.Incident, error) {
return s.repo.ListIncidents(status, limit)
}
// GetIncident returns an incident by ID.
func (s *Service) GetIncident(id string) (*domsoc.Incident, error) {
return s.repo.GetIncident(id)
}
// UpdateVerdict updates an incident's status (manual verdict).
func (s *Service) UpdateVerdict(id string, status domsoc.IncidentStatus) error {
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
fmt.Sprintf("VERDICT:%s", status),
fmt.Sprintf("incident=%s", id))
}
return s.repo.UpdateIncidentStatus(id, status)
}
// ListSensors returns all registered sensors.
func (s *Service) ListSensors() ([]domsoc.Sensor, error) {
return s.repo.ListSensors()
}
// Dashboard returns SOC KPI metrics.
func (s *Service) Dashboard() (*DashboardData, error) {
totalEvents, err := s.repo.CountEvents()
if err != nil {
return nil, err
}
lastHourEvents, err := s.repo.CountEventsSince(time.Now().Add(-1 * time.Hour))
if err != nil {
return nil, err
}
openIncidents, err := s.repo.CountOpenIncidents()
if err != nil {
return nil, err
}
sensorCounts, err := s.repo.CountSensorsByStatus()
if err != nil {
return nil, err
}
// Chain validation (§5.6, §12.2) — full SHA-256 chain verification.
chainValid := false
chainLength := 0
chainHeadHash := ""
chainBrokenLine := 0
if s.logger != nil {
chainLength = s.logger.Count()
chainHeadHash = s.logger.PrevHash()
// Full chain verification via VerifyChainFromFile (§5.6)
validCount, brokenLine, verifyErr := audit.VerifyChainFromFile(s.logger.Path())
if verifyErr == nil && brokenLine == 0 {
chainValid = true
chainLength = validCount // Use file-verified count
} else {
chainBrokenLine = brokenLine
}
}
return &DashboardData{
TotalEvents: totalEvents,
EventsLastHour: lastHourEvents,
OpenIncidents: openIncidents,
SensorStatus: sensorCounts,
ChainValid: chainValid,
ChainLength: chainLength,
ChainHeadHash: chainHeadHash,
ChainBrokenLine: chainBrokenLine,
CorrelationRules: len(s.rules),
ActivePlaybooks: len(s.playbooks),
}, nil
}
// Analytics generates a full SOC analytics report for the given time window.
func (s *Service) Analytics(windowHours int) (*AnalyticsReport, error) {
events, err := s.repo.ListEvents(10000) // large window
if err != nil {
return nil, fmt.Errorf("soc: analytics events: %w", err)
}
incidents, err := s.repo.ListIncidents("", 1000)
if err != nil {
return nil, fmt.Errorf("soc: analytics incidents: %w", err)
}
return GenerateReport(events, incidents, windowHours), nil
}
// DashboardData holds SOC KPI metrics (§12.2).
type DashboardData struct {
TotalEvents int `json:"total_events"`
EventsLastHour int `json:"events_last_hour"`
OpenIncidents int `json:"open_incidents"`
SensorStatus map[domsoc.SensorStatus]int `json:"sensor_status"`
ChainValid bool `json:"chain_valid"`
ChainLength int `json:"chain_length"`
ChainHeadHash string `json:"chain_head_hash"`
ChainBrokenLine int `json:"chain_broken_line,omitempty"`
CorrelationRules int `json:"correlation_rules"`
ActivePlaybooks int `json:"active_playbooks"`
}
// JSON returns the dashboard as JSON string.
func (d *DashboardData) JSON() string {
data, _ := json.MarshalIndent(d, "", " ")
return string(data)
}
// RunPlaybook manually executes a playbook against an incident (§10, §12.1).
func (s *Service) RunPlaybook(playbookID, incidentID string) (*PlaybookResult, error) {
// Find playbook.
var pb *domsoc.Playbook
for i := range s.playbooks {
if s.playbooks[i].ID == playbookID {
pb = &s.playbooks[i]
break
}
}
if pb == nil {
return nil, fmt.Errorf("playbook not found: %s", playbookID)
}
// Find incident.
incident, err := s.repo.GetIncident(incidentID)
if err != nil {
return nil, fmt.Errorf("incident not found: %s", incidentID)
}
incident.PlaybookApplied = pb.ID
if err := s.repo.UpdateIncidentStatus(incidentID, domsoc.StatusInvestigating); err != nil {
return nil, fmt.Errorf("soc: update incident: %w", err)
}
if s.logger != nil {
s.logger.Record(audit.ModuleSOC,
fmt.Sprintf("PLAYBOOK_MANUAL_RUN:%s", pb.ID),
fmt.Sprintf("incident=%s actions=%v", incidentID, pb.Actions))
}
return &PlaybookResult{
PlaybookID: pb.ID,
IncidentID: incidentID,
Actions: pb.Actions,
Status: "EXECUTED",
}, nil
}
// PlaybookResult represents the result of a manual playbook run.
type PlaybookResult struct {
PlaybookID string `json:"playbook_id"`
IncidentID string `json:"incident_id"`
Actions []domsoc.PlaybookAction `json:"actions"`
Status string `json:"status"`
}
// ComplianceReport generates an EU AI Act Article 15 compliance report (§12.3).
func (s *Service) ComplianceReport() (*ComplianceData, error) {
dashboard, err := s.Dashboard()
if err != nil {
return nil, err
}
sensors, err := s.repo.ListSensors()
if err != nil {
return nil, err
}
// Build compliance requirements check.
requirements := []ComplianceRequirement{
{
ID: "15.1",
Description: "Risk Management System",
Status: "COMPLIANT",
Evidence: []string{"soc_correlation_engine", "soc_playbooks", fmt.Sprintf("rules=%d", len(s.rules))},
},
{
ID: "15.2",
Description: "Data Governance",
Status: boolToCompliance(dashboard.ChainValid),
Evidence: []string{"decision_logger_sha256", fmt.Sprintf("chain_length=%d", dashboard.ChainLength)},
},
{
ID: "15.3",
Description: "Technical Documentation",
Status: "COMPLIANT",
Evidence: []string{"SENTINEL_AI_SOC_SPEC.md", "soc_dashboard_kpis"},
},
{
ID: "15.4",
Description: "Record-keeping",
Status: boolToCompliance(dashboard.ChainValid && dashboard.ChainLength > 0),
Evidence: []string{"decisions.log", fmt.Sprintf("chain_valid=%t", dashboard.ChainValid)},
},
{
ID: "15.5",
Description: "Transparency",
Status: "PARTIAL",
Evidence: []string{"soc_dashboard_screenshots.pdf"},
Gap: "Real-time explainability of correlation decisions — planned for v1.2",
},
{
ID: "15.6",
Description: "Human Oversight",
Status: "COMPLIANT",
Evidence: []string{"soc_verdict_tool", "manual_playbook_run", fmt.Sprintf("sensors=%d", len(sensors))},
},
}
return &ComplianceData{
Framework: "EU AI Act Article 15",
GeneratedAt: time.Now(),
Requirements: requirements,
Overall: overallStatus(requirements),
}, nil
}
// ComplianceData holds an EU AI Act compliance report (§12.3).
type ComplianceData struct {
Framework string `json:"framework"`
GeneratedAt time.Time `json:"generated_at"`
Requirements []ComplianceRequirement `json:"requirements"`
Overall string `json:"overall"`
}
// ComplianceRequirement is a single compliance check.
type ComplianceRequirement struct {
ID string `json:"id"`
Description string `json:"description"`
Status string `json:"status"` // COMPLIANT, PARTIAL, NON_COMPLIANT
Evidence []string `json:"evidence"`
Gap string `json:"gap,omitempty"`
}
func boolToCompliance(ok bool) string {
if ok {
return "COMPLIANT"
}
return "NON_COMPLIANT"
}
func overallStatus(reqs []ComplianceRequirement) string {
for _, r := range reqs {
if r.Status == "NON_COMPLIANT" {
return "NON_COMPLIANT"
}
}
for _, r := range reqs {
if r.Status == "PARTIAL" {
return "PARTIAL"
}
}
return "COMPLIANT"
}
// ExportIncidents converts all current incidents into portable SyncIncident format
// for P2P synchronization (§10 T-01).
func (s *Service) ExportIncidents(sourcePeerID string) []peer.SyncIncident {
s.mu.RLock()
defer s.mu.RUnlock()
incidents, err := s.repo.ListIncidents("", 1000)
if err != nil || len(incidents) == 0 {
return nil
}
result := make([]peer.SyncIncident, 0, len(incidents))
for _, inc := range incidents {
result = append(result, peer.SyncIncident{
ID: inc.ID,
Status: string(inc.Status),
Severity: string(inc.Severity),
Title: inc.Title,
Description: inc.Description,
EventCount: inc.EventCount,
CorrelationRule: inc.CorrelationRule,
KillChainPhase: inc.KillChainPhase,
MITREMapping: inc.MITREMapping,
CreatedAt: inc.CreatedAt,
SourcePeerID: sourcePeerID,
})
}
return result
}
// ImportIncidents ingests incidents from a trusted peer (§10 T-01).
// Uses UPDATE-or-INSERT semantics: new incidents are created, existing IDs are skipped.
// Returns the number of newly imported incidents.
func (s *Service) ImportIncidents(incidents []peer.SyncIncident) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
imported := 0
for _, si := range incidents {
// Convert back to domain incident.
inc := domsoc.Incident{
ID: si.ID,
Status: domsoc.IncidentStatus(si.Status),
Severity: domsoc.EventSeverity(si.Severity),
Title: fmt.Sprintf("[P2P:%s] %s", si.SourcePeerID, si.Title),
Description: si.Description,
EventCount: si.EventCount,
CorrelationRule: si.CorrelationRule,
KillChainPhase: si.KillChainPhase,
MITREMapping: si.MITREMapping,
CreatedAt: si.CreatedAt,
UpdatedAt: time.Now(),
}
err := s.repo.InsertIncident(inc)
if err != nil {
return imported, fmt.Errorf("import incident %s: %w", si.ID, err)
}
imported++
}
if s.logger != nil {
s.logger.Record(audit.ModuleSOC, "P2P_INCIDENT_SYNC",
fmt.Sprintf("imported=%d total=%d", imported, len(incidents)))
}
return imported, nil
}

View file

@ -0,0 +1,211 @@
package soc
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
domsoc "github.com/sentinel-community/gomcp/internal/domain/soc"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
)
// newTestService creates a SOC service backed by in-memory SQLite, without a decision logger.
func newTestService(t *testing.T) *Service {
t.Helper()
db, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
repo, err := sqlite.NewSOCRepo(db)
require.NoError(t, err)
return NewService(repo, nil)
}
// --- Rate Limiting Tests (§17.3, §18.2 PB-05) ---
func TestIsRateLimited_UnderLimit(t *testing.T) {
svc := newTestService(t)
// 100 events should NOT trigger rate limit.
for i := 0; i < 100; i++ {
event := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "rate test")
event.ID = fmt.Sprintf("evt-under-%d", i) // Unique ID
event.SensorID = "sensor-A"
_, _, err := svc.IngestEvent(event)
require.NoError(t, err, "event %d should not be rate limited", i+1)
}
}
func TestIsRateLimited_OverLimit(t *testing.T) {
svc := newTestService(t)
// Send 101 events — the 101st should be rate limited.
for i := 0; i < MaxEventsPerSecondPerSensor; i++ {
event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityLow, "test", "rate test")
event.ID = fmt.Sprintf("evt-over-%d", i) // Unique ID
event.SensorID = "sensor-B"
_, _, err := svc.IngestEvent(event)
require.NoError(t, err, "event %d should pass", i+1)
}
// 101st event — should be rejected.
event := domsoc.NewSOCEvent(domsoc.SourceShield, domsoc.SeverityLow, "test", "overflow")
event.ID = "evt-over-101"
event.SensorID = "sensor-B"
_, _, err := svc.IngestEvent(event)
require.Error(t, err)
assert.Contains(t, err.Error(), "rate limit exceeded")
assert.Contains(t, err.Error(), "sensor-B")
}
func TestIsRateLimited_DifferentSensors(t *testing.T) {
svc := newTestService(t)
// 100 events from sensor-C.
for i := 0; i < MaxEventsPerSecondPerSensor; i++ {
event := domsoc.NewSOCEvent(domsoc.SourceGoMCP, domsoc.SeverityLow, "test", "sensor C")
event.ID = fmt.Sprintf("evt-diff-C-%d", i) // Unique ID
event.SensorID = "sensor-C"
_, _, err := svc.IngestEvent(event)
require.NoError(t, err)
}
// sensor-D should still accept events (independent rate limiter).
event := domsoc.NewSOCEvent(domsoc.SourceGoMCP, domsoc.SeverityLow, "test", "sensor D")
event.ID = "evt-diff-D-0"
event.SensorID = "sensor-D"
_, _, err := svc.IngestEvent(event)
require.NoError(t, err, "sensor-D should not be affected by sensor-C rate limit")
}
func TestIsRateLimited_FallsBackToSource(t *testing.T) {
svc := newTestService(t)
// When SensorID is empty, should use Source as key.
for i := 0; i < MaxEventsPerSecondPerSensor; i++ {
event := domsoc.NewSOCEvent(domsoc.SourceExternal, domsoc.SeverityLow, "test", "no sensor id")
event.ID = fmt.Sprintf("evt-fb-%d", i) // Unique ID
_, _, err := svc.IngestEvent(event)
require.NoError(t, err)
}
// 101st from same source — should be limited.
event := domsoc.NewSOCEvent(domsoc.SourceExternal, domsoc.SeverityLow, "test", "overflow no sensor")
event.ID = "evt-fb-101"
_, _, err := svc.IngestEvent(event)
require.Error(t, err)
assert.Contains(t, err.Error(), "rate limit exceeded")
}
// --- Compliance Report Tests (§12.3) ---
func TestComplianceReport_GeneratesReport(t *testing.T) {
svc := newTestService(t)
report, err := svc.ComplianceReport()
require.NoError(t, err)
require.NotNil(t, report)
assert.Equal(t, "EU AI Act Article 15", report.Framework)
assert.NotEmpty(t, report.Requirements)
assert.Len(t, report.Requirements, 6) // 15.1 through 15.6
// Without a decision logger, chain is invalid → 15.2/15.4 are NON_COMPLIANT.
// With NON_COMPLIANT present, overall is NON_COMPLIANT.
// 15.5 Transparency is always PARTIAL.
foundPartial := false
for _, r := range report.Requirements {
if r.Status == "PARTIAL" {
foundPartial = true
assert.NotEmpty(t, r.Gap)
}
}
assert.True(t, foundPartial, "should have at least one PARTIAL requirement")
// Overall should be NON_COMPLIANT because no Decision Logger → chain invalid.
assert.Equal(t, "NON_COMPLIANT", report.Overall)
}
// --- RunPlaybook Tests (§10, §12.1) ---
func TestRunPlaybook_NotFound(t *testing.T) {
svc := newTestService(t)
_, err := svc.RunPlaybook("nonexistent-pb", "inc-123")
require.Error(t, err)
assert.Contains(t, err.Error(), "playbook not found")
}
func TestRunPlaybook_IncidentNotFound(t *testing.T) {
svc := newTestService(t)
// Use a valid playbook ID from defaults.
_, err := svc.RunPlaybook("pb-auto-block-jailbreak", "nonexistent-inc")
require.Error(t, err)
assert.Contains(t, err.Error(), "incident not found")
}
// --- Secret Scanner Integration Tests (§5.4) ---
func TestSecretScanner_RejectsSecrets(t *testing.T) {
svc := newTestService(t)
event := domsoc.NewSOCEvent(domsoc.SourceExternal, domsoc.SeverityMedium, "test", "test event")
event.Payload = "my API key is AKIA1234567890ABCDEF" // AWS-style key
_, _, err := svc.IngestEvent(event)
if err != nil {
// If ScanForSecrets detected it, we expect rejection.
assert.Contains(t, err.Error(), "secret scanner rejected")
}
// If no secrets detected (depends on oracle implementation), event passes.
}
func TestSecretScanner_AllowsClean(t *testing.T) {
svc := newTestService(t)
event := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityLow, "test", "clean event")
event.Payload = "this is a normal log message with no secrets"
id, _, err := svc.IngestEvent(event)
require.NoError(t, err)
assert.NotEmpty(t, id)
}
// --- Zero-G Mode Tests (§13.4) ---
func TestZeroGMode_SkipsPlaybook(t *testing.T) {
svc := newTestService(t)
event := domsoc.NewSOCEvent(domsoc.SourceSentinelCore, domsoc.SeverityCritical, "jailbreak", "zero-g test")
event.ZeroGMode = true
id, _, err := svc.IngestEvent(event)
require.NoError(t, err)
assert.NotEmpty(t, id)
}
// --- Helper tests ---
func TestBoolToCompliance(t *testing.T) {
assert.Equal(t, "COMPLIANT", boolToCompliance(true))
assert.Equal(t, "NON_COMPLIANT", boolToCompliance(false))
}
func TestOverallStatus(t *testing.T) {
tests := []struct {
name string
reqs []ComplianceRequirement
want string
}{
{"all compliant", []ComplianceRequirement{{Status: "COMPLIANT"}, {Status: "COMPLIANT"}}, "COMPLIANT"},
{"one partial", []ComplianceRequirement{{Status: "COMPLIANT"}, {Status: "PARTIAL"}}, "PARTIAL"},
{"one non-compliant", []ComplianceRequirement{{Status: "COMPLIANT"}, {Status: "NON_COMPLIANT"}}, "NON_COMPLIANT"},
{"non-compliant wins", []ComplianceRequirement{{Status: "PARTIAL"}, {Status: "NON_COMPLIANT"}}, "NON_COMPLIANT"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, overallStatus(tt.reqs))
})
}
}

View file

@ -0,0 +1,364 @@
// Package soc provides a threat intelligence feed integration
// for enriching SOC events and correlation rules.
//
// Supports:
// - STIX/TAXII 2.1 feeds (JSON)
// - CSV IOC lists (hashes, IPs, domains)
// - Local file-based IOC database
// - Periodic background refresh
package soc
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"sync"
"time"
)
// ─── IOC Types ──────────────────────────────────────────
// IOCType represents the type of Indicator of Compromise.
type IOCType string
const (
IOCTypeIP IOCType = "ipv4-addr"
IOCTypeDomain IOCType = "domain-name"
IOCTypeHash IOCType = "file:hashes"
IOCTypeURL IOCType = "url"
IOCCVE IOCType = "vulnerability"
IOCPattern IOCType = "pattern"
)
// IOC is an Indicator of Compromise.
type IOC struct {
Type IOCType `json:"type"`
Value string `json:"value"`
Source string `json:"source"` // Feed name
Severity string `json:"severity"` // critical/high/medium/low
Tags []string `json:"tags"` // MITRE ATT&CK, campaign, etc.
FirstSeen time.Time `json:"first_seen"`
LastSeen time.Time `json:"last_seen"`
Confidence float64 `json:"confidence"` // 0.0-1.0
}
// ThreatFeed represents a configured threat intelligence source.
type ThreatFeed struct {
Name string `json:"name"`
URL string `json:"url"`
Type string `json:"type"` // stix, csv, json
Enabled bool `json:"enabled"`
Interval time.Duration `json:"interval"`
APIKey string `json:"api_key,omitempty"`
LastFetch time.Time `json:"last_fetch"`
IOCCount int `json:"ioc_count"`
LastError string `json:"last_error,omitempty"`
}
// ─── Threat Intel Store ─────────────────────────────────
// ThreatIntelStore manages IOCs from multiple feeds.
type ThreatIntelStore struct {
mu sync.RWMutex
iocs map[string]*IOC // key: type:value
feeds []ThreatFeed
client *http.Client
// Stats
TotalIOCs int `json:"total_iocs"`
TotalFeeds int `json:"total_feeds"`
LastRefresh time.Time `json:"last_refresh"`
MatchesFound int64 `json:"matches_found"`
}
// NewThreatIntelStore creates an empty threat intel store.
func NewThreatIntelStore() *ThreatIntelStore {
return &ThreatIntelStore{
iocs: make(map[string]*IOC),
client: &http.Client{Timeout: 30 * time.Second},
}
}
// AddFeed registers a threat intel feed.
func (t *ThreatIntelStore) AddFeed(feed ThreatFeed) {
t.mu.Lock()
defer t.mu.Unlock()
t.feeds = append(t.feeds, feed)
t.TotalFeeds = len(t.feeds)
}
// AddIOC adds or updates an indicator.
func (t *ThreatIntelStore) AddIOC(ioc IOC) {
t.mu.Lock()
defer t.mu.Unlock()
key := fmt.Sprintf("%s:%s", ioc.Type, strings.ToLower(ioc.Value))
if existing, ok := t.iocs[key]; ok {
// Update — keep earliest first_seen, latest last_seen
if ioc.FirstSeen.Before(existing.FirstSeen) {
existing.FirstSeen = ioc.FirstSeen
}
existing.LastSeen = ioc.LastSeen
if ioc.Confidence > existing.Confidence {
existing.Confidence = ioc.Confidence
}
} else {
t.iocs[key] = &ioc
t.TotalIOCs = len(t.iocs)
}
}
// Lookup checks if a value matches any known IOC.
// Returns nil if not found.
func (t *ThreatIntelStore) Lookup(iocType IOCType, value string) *IOC {
t.mu.RLock()
key := fmt.Sprintf("%s:%s", iocType, strings.ToLower(value))
ioc, ok := t.iocs[key]
t.mu.RUnlock()
if ok {
t.mu.Lock()
t.MatchesFound++
t.mu.Unlock()
return ioc
}
return nil
}
// LookupAny checks value against all IOC types (broad search).
func (t *ThreatIntelStore) LookupAny(value string) []*IOC {
t.mu.RLock()
defer t.mu.RUnlock()
lowValue := strings.ToLower(value)
var matches []*IOC
for key, ioc := range t.iocs {
if strings.HasSuffix(key, ":"+lowValue) {
matches = append(matches, ioc)
}
}
return matches
}
// EnrichEvent checks event fields against IOC database and returns matches.
func (t *ThreatIntelStore) EnrichEvent(sourceIP, description string) []IOC {
var matches []IOC
// Check source IP
if sourceIP != "" {
if ioc := t.Lookup(IOCTypeIP, sourceIP); ioc != nil {
matches = append(matches, *ioc)
}
}
// Check description for domain/URL IOCs
if description != "" {
words := strings.Fields(description)
for _, word := range words {
word = strings.Trim(word, ".,;:\"'()[]{}!")
if strings.Contains(word, ".") && len(word) > 4 {
if ioc := t.Lookup(IOCTypeDomain, word); ioc != nil {
matches = append(matches, *ioc)
}
}
}
}
return matches
}
// ─── Feed Fetching ──────────────────────────────────────
// RefreshAll fetches all enabled feeds and updates IOC database.
func (t *ThreatIntelStore) RefreshAll() error {
t.mu.RLock()
feeds := make([]ThreatFeed, len(t.feeds))
copy(feeds, t.feeds)
t.mu.RUnlock()
var errs []string
for i, feed := range feeds {
if !feed.Enabled {
continue
}
iocs, err := t.fetchFeed(feed)
if err != nil {
feeds[i].LastError = err.Error()
errs = append(errs, fmt.Sprintf("%s: %v", feed.Name, err))
continue
}
for _, ioc := range iocs {
t.AddIOC(ioc)
}
feeds[i].LastFetch = time.Now()
feeds[i].IOCCount = len(iocs)
feeds[i].LastError = ""
}
// Update feed states
t.mu.Lock()
t.feeds = feeds
t.LastRefresh = time.Now()
t.mu.Unlock()
if len(errs) > 0 {
return fmt.Errorf("feed errors: %s", strings.Join(errs, "; "))
}
return nil
}
// fetchFeed retrieves IOCs from a single feed.
func (t *ThreatIntelStore) fetchFeed(feed ThreatFeed) ([]IOC, error) {
req, err := http.NewRequest("GET", feed.URL, nil)
if err != nil {
return nil, err
}
if feed.APIKey != "" {
req.Header.Set("Authorization", "Bearer "+feed.APIKey)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "SENTINEL-ThreatIntel/1.0")
resp, err := t.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
switch feed.Type {
case "stix":
return t.parseSTIX(resp)
case "json":
return t.parseJSON(resp)
default:
return nil, fmt.Errorf("unsupported feed type: %s", feed.Type)
}
}
// parseSTIX parses STIX 2.1 bundle response.
func (t *ThreatIntelStore) parseSTIX(resp *http.Response) ([]IOC, error) {
var bundle struct {
Type string `json:"type"`
ID string `json:"id"`
Objects json.RawMessage `json:"objects"`
}
if err := json.NewDecoder(resp.Body).Decode(&bundle); err != nil {
return nil, fmt.Errorf("stix parse: %w", err)
}
var objects []struct {
Type string `json:"type"`
Pattern string `json:"pattern"`
Name string `json:"name"`
}
if err := json.Unmarshal(bundle.Objects, &objects); err != nil {
return nil, fmt.Errorf("stix objects: %w", err)
}
var iocs []IOC
now := time.Now()
for _, obj := range objects {
if obj.Type != "indicator" {
continue
}
iocs = append(iocs, IOC{
Type: IOCPattern,
Value: obj.Pattern,
Source: "stix",
FirstSeen: now,
LastSeen: now,
Confidence: 0.8,
})
}
return iocs, nil
}
// parseJSON parses a simple JSON IOC list.
func (t *ThreatIntelStore) parseJSON(resp *http.Response) ([]IOC, error) {
var iocs []IOC
if err := json.NewDecoder(resp.Body).Decode(&iocs); err != nil {
return nil, fmt.Errorf("json parse: %w", err)
}
return iocs, nil
}
// ─── Background Refresh ─────────────────────────────────
// StartBackgroundRefresh runs periodic feed refresh in a goroutine.
func (t *ThreatIntelStore) StartBackgroundRefresh(interval time.Duration, stop <-chan struct{}) {
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
// Initial fetch
if err := t.RefreshAll(); err != nil {
log.Printf("[ThreatIntel] initial refresh error: %v", err)
}
for {
select {
case <-ticker.C:
if err := t.RefreshAll(); err != nil {
log.Printf("[ThreatIntel] refresh error: %v", err)
} else {
log.Printf("[ThreatIntel] refreshed: %d IOCs from %d feeds",
t.TotalIOCs, t.TotalFeeds)
}
case <-stop:
return
}
}
}()
}
// Stats returns threat intel statistics.
func (t *ThreatIntelStore) Stats() map[string]interface{} {
t.mu.RLock()
defer t.mu.RUnlock()
return map[string]interface{}{
"total_iocs": t.TotalIOCs,
"total_feeds": t.TotalFeeds,
"last_refresh": t.LastRefresh,
"matches_found": t.MatchesFound,
"feeds": t.feeds,
}
}
// GetFeeds returns all configured feeds with their status.
func (t *ThreatIntelStore) GetFeeds() []ThreatFeed {
t.mu.RLock()
defer t.mu.RUnlock()
feeds := make([]ThreatFeed, len(t.feeds))
copy(feeds, t.feeds)
return feeds
}
// AddDefaultFeeds registers SENTINEL-native threat feeds.
func (t *ThreatIntelStore) AddDefaultFeeds() {
t.AddFeed(ThreatFeed{
Name: "OWASP LLM Top 10",
Type: "json",
Enabled: false, // Enable when URL configured
Interval: 24 * time.Hour,
})
t.AddFeed(ThreatFeed{
Name: "MITRE ATLAS",
Type: "stix",
Enabled: false,
Interval: 12 * time.Hour,
})
t.AddFeed(ThreatFeed{
Name: "SENTINEL Community IOCs",
Type: "json",
Enabled: false,
Interval: 1 * time.Hour,
})
}

View file

@ -0,0 +1,181 @@
package soc
import (
"testing"
"time"
)
func TestThreatIntelStore_AddAndLookup(t *testing.T) {
store := NewThreatIntelStore()
ioc := IOC{
Type: IOCTypeIP,
Value: "192.168.1.100",
Source: "test-feed",
Severity: "high",
FirstSeen: time.Now(),
LastSeen: time.Now(),
Confidence: 0.9,
}
store.AddIOC(ioc)
if store.TotalIOCs != 1 {
t.Errorf("expected 1 IOC, got %d", store.TotalIOCs)
}
found := store.Lookup(IOCTypeIP, "192.168.1.100")
if found == nil {
t.Fatal("expected to find IOC")
}
if found.Source != "test-feed" {
t.Errorf("expected source test-feed, got %s", found.Source)
}
}
func TestThreatIntelStore_LookupNotFound(t *testing.T) {
store := NewThreatIntelStore()
found := store.Lookup(IOCTypeIP, "10.0.0.1")
if found != nil {
t.Error("expected nil for unknown IOC")
}
}
func TestThreatIntelStore_CaseInsensitiveLookup(t *testing.T) {
store := NewThreatIntelStore()
store.AddIOC(IOC{
Type: IOCTypeDomain,
Value: "evil.example.COM",
Source: "test",
FirstSeen: time.Now(),
LastSeen: time.Now(),
Confidence: 0.8,
})
// Lookup with different case
found := store.Lookup(IOCTypeDomain, "evil.example.com")
if found == nil {
t.Fatal("expected case-insensitive match")
}
}
func TestThreatIntelStore_UpdateExisting(t *testing.T) {
store := NewThreatIntelStore()
now := time.Now()
earlier := now.Add(-24 * time.Hour)
store.AddIOC(IOC{
Type: IOCTypeIP,
Value: "10.0.0.1",
Source: "feed1",
FirstSeen: now,
LastSeen: now,
Confidence: 0.5,
})
// Second add with earlier FirstSeen and higher confidence
store.AddIOC(IOC{
Type: IOCTypeIP,
Value: "10.0.0.1",
Source: "feed2",
FirstSeen: earlier,
LastSeen: now,
Confidence: 0.95,
})
// Should still be 1 IOC (merged)
if store.TotalIOCs != 1 {
t.Errorf("expected 1 IOC after merge, got %d", store.TotalIOCs)
}
found := store.Lookup(IOCTypeIP, "10.0.0.1")
if found == nil {
t.Fatal("expected to find merged IOC")
}
if found.Confidence != 0.95 {
t.Errorf("expected confidence 0.95 after merge, got %.2f", found.Confidence)
}
if !found.FirstSeen.Equal(earlier) {
t.Error("expected FirstSeen to be earlier timestamp after merge")
}
}
func TestThreatIntelStore_LookupAny(t *testing.T) {
store := NewThreatIntelStore()
store.AddIOC(IOC{Type: IOCTypeIP, Value: "10.0.0.1", FirstSeen: time.Now(), LastSeen: time.Now()})
store.AddIOC(IOC{Type: IOCTypeDomain, Value: "10.0.0.1", FirstSeen: time.Now(), LastSeen: time.Now()})
matches := store.LookupAny("10.0.0.1")
if len(matches) != 2 {
t.Errorf("expected 2 matches (IP + domain), got %d", len(matches))
}
}
func TestThreatIntelStore_EnrichEvent(t *testing.T) {
store := NewThreatIntelStore()
store.AddIOC(IOC{
Type: IOCTypeIP,
Value: "malicious-sensor",
Source: "intel",
Severity: "critical",
FirstSeen: time.Now(),
LastSeen: time.Now(),
Confidence: 0.99,
})
// Enrich event with matching sensorID as sourceIP
matches := store.EnrichEvent("malicious-sensor", "normal traffic")
if len(matches) != 1 {
t.Errorf("expected 1 IOC match, got %d", len(matches))
}
}
func TestThreatIntelStore_EnrichEvent_DomainInDescription(t *testing.T) {
store := NewThreatIntelStore()
store.AddIOC(IOC{
Type: IOCTypeDomain,
Value: "evil.example.com",
Source: "stix",
FirstSeen: time.Now(),
LastSeen: time.Now(),
})
matches := store.EnrichEvent("", "Request to evil.example.com detected")
if len(matches) != 1 {
t.Errorf("expected 1 domain match in description, got %d", len(matches))
}
}
func TestThreatIntelStore_AddDefaultFeeds(t *testing.T) {
store := NewThreatIntelStore()
store.AddDefaultFeeds()
if store.TotalFeeds != 3 {
t.Errorf("expected 3 default feeds, got %d", store.TotalFeeds)
}
feeds := store.GetFeeds()
for _, f := range feeds {
if f.Enabled {
t.Errorf("default feed %s should be disabled", f.Name)
}
}
}
func TestThreatIntelStore_Stats(t *testing.T) {
store := NewThreatIntelStore()
store.AddIOC(IOC{Type: IOCTypeIP, Value: "1.2.3.4", FirstSeen: time.Now(), LastSeen: time.Now()})
store.AddDefaultFeeds()
stats := store.Stats()
if stats["total_iocs"] != 1 {
t.Errorf("expected total_iocs=1, got %v", stats["total_iocs"])
}
if stats["total_feeds"] != 3 {
t.Errorf("expected total_feeds=3, got %v", stats["total_feeds"])
}
}

View file

@ -0,0 +1,247 @@
// Package webhook provides outbound SOAR webhook notifications
// for the SOC pipeline. Fires HTTP POST on incident creation/update.
package soc
import (
"bytes"
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"sync"
"time"
domsoc "github.com/sentinel-community/gomcp/internal/domain/soc"
)
// WebhookConfig holds SOAR webhook settings.
type WebhookConfig struct {
// Endpoints is a list of webhook URLs to POST to.
Endpoints []string `json:"endpoints"`
// Headers are custom HTTP headers added to every request (e.g., auth tokens).
Headers map[string]string `json:"headers,omitempty"`
// MaxRetries is the number of retry attempts on failure (default 3).
MaxRetries int `json:"max_retries"`
// TimeoutSec is the HTTP client timeout in seconds (default 10).
TimeoutSec int `json:"timeout_sec"`
// MinSeverity filters: only incidents >= this severity trigger webhooks.
// Empty string means all severities.
MinSeverity domsoc.EventSeverity `json:"min_severity,omitempty"`
}
// WebhookPayload is the JSON body sent to SOAR endpoints.
type WebhookPayload struct {
EventType string `json:"event_type"` // incident_created, incident_updated, sensor_offline
Timestamp time.Time `json:"timestamp"`
Source string `json:"source"`
Data json.RawMessage `json:"data"`
}
// WebhookResult tracks delivery status per endpoint.
type WebhookResult struct {
Endpoint string `json:"endpoint"`
StatusCode int `json:"status_code"`
Success bool `json:"success"`
Retries int `json:"retries"`
Error string `json:"error,omitempty"`
}
// WebhookNotifier handles outbound SOAR notifications.
type WebhookNotifier struct {
mu sync.RWMutex
config WebhookConfig
client *http.Client
enabled bool
// Stats
Sent int64 `json:"sent"`
Failed int64 `json:"failed"`
}
// NewWebhookNotifier creates a notifier with the given config.
func NewWebhookNotifier(config WebhookConfig) *WebhookNotifier {
if config.MaxRetries <= 0 {
config.MaxRetries = 3
}
timeout := time.Duration(config.TimeoutSec) * time.Second
if timeout <= 0 {
timeout = 10 * time.Second
}
return &WebhookNotifier{
config: config,
client: &http.Client{Timeout: timeout},
enabled: len(config.Endpoints) > 0,
}
}
// severityRank returns numeric rank for severity comparison.
func severityRank(s domsoc.EventSeverity) int {
switch s {
case domsoc.SeverityCritical:
return 5
case domsoc.SeverityHigh:
return 4
case domsoc.SeverityMedium:
return 3
case domsoc.SeverityLow:
return 2
case domsoc.SeverityInfo:
return 1
default:
return 0
}
}
// NotifyIncident sends an incident webhook to all configured endpoints.
// Non-blocking: fires goroutines for each endpoint.
func (w *WebhookNotifier) NotifyIncident(eventType string, incident *domsoc.Incident) []WebhookResult {
if !w.enabled || incident == nil {
return nil
}
// Severity filter
if w.config.MinSeverity != "" {
if severityRank(incident.Severity) < severityRank(w.config.MinSeverity) {
return nil
}
}
data, err := json.Marshal(incident)
if err != nil {
return nil
}
payload := WebhookPayload{
EventType: eventType,
Timestamp: time.Now().UTC(),
Source: "sentinel-soc",
Data: data,
}
body, err := json.Marshal(payload)
if err != nil {
return nil
}
// Fire all endpoints in parallel
var wg sync.WaitGroup
results := make([]WebhookResult, len(w.config.Endpoints))
for i, endpoint := range w.config.Endpoints {
wg.Add(1)
go func(idx int, url string) {
defer wg.Done()
results[idx] = w.sendWithRetry(url, body)
}(i, endpoint)
}
wg.Wait()
// Update stats
w.mu.Lock()
for _, r := range results {
if r.Success {
w.Sent++
} else {
w.Failed++
}
}
w.mu.Unlock()
return results
}
// NotifySensorOffline sends a sensor offline alert to all endpoints.
func (w *WebhookNotifier) NotifySensorOffline(sensor domsoc.Sensor) []WebhookResult {
if !w.enabled {
return nil
}
data, _ := json.Marshal(sensor)
payload := WebhookPayload{
EventType: "sensor_offline",
Timestamp: time.Now().UTC(),
Source: "sentinel-soc",
Data: data,
}
body, _ := json.Marshal(payload)
var wg sync.WaitGroup
results := make([]WebhookResult, len(w.config.Endpoints))
for i, endpoint := range w.config.Endpoints {
wg.Add(1)
go func(idx int, url string) {
defer wg.Done()
results[idx] = w.sendWithRetry(url, body)
}(i, endpoint)
}
wg.Wait()
return results
}
// sendWithRetry sends POST request with exponential backoff.
func (w *WebhookNotifier) sendWithRetry(url string, body []byte) WebhookResult {
result := WebhookResult{Endpoint: url}
for attempt := 0; attempt <= w.config.MaxRetries; attempt++ {
result.Retries = attempt
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
if err != nil {
result.Error = err.Error()
return result
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "SENTINEL-SOC/1.0")
req.Header.Set("X-Sentinel-Event", "soc-webhook")
// Add custom headers
for k, v := range w.config.Headers {
req.Header.Set(k, v)
}
resp, err := w.client.Do(req)
if err != nil {
result.Error = err.Error()
if attempt < w.config.MaxRetries {
backoff := time.Duration(1<<uint(attempt)) * 500 * time.Millisecond
jitter := time.Duration(rand.Intn(500)) * time.Millisecond
time.Sleep(backoff + jitter)
continue
}
return result
}
resp.Body.Close()
result.StatusCode = resp.StatusCode
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
result.Success = true
return result
}
result.Error = fmt.Sprintf("HTTP %d", resp.StatusCode)
if attempt < w.config.MaxRetries {
backoff := time.Duration(1<<uint(attempt)) * 500 * time.Millisecond
jitter := time.Duration(rand.Intn(500)) * time.Millisecond
time.Sleep(backoff + jitter)
}
}
log.Printf("[SOC] webhook failed after %d retries: %s → %s", w.config.MaxRetries, url, result.Error)
return result
}
// Stats returns webhook delivery stats.
func (w *WebhookNotifier) Stats() (sent, failed int64) {
w.mu.RLock()
defer w.mu.RUnlock()
return w.Sent, w.Failed
}

View file

@ -0,0 +1,181 @@
// Package tools — Apathy Detection and Apoptosis Recovery (DIP H1.4).
//
// This file implements:
// 1. ApathyDetector — analyzes text signals for infrastructure apathy patterns
// (blocked responses, 403 errors, semantic filters, forced resets)
// 2. ApoptosisRecovery — on critical entropy, saves genome hash to protected
// sector for cross-session recovery
package tools
import (
"context"
"fmt"
"strings"
"time"
"github.com/sentinel-community/gomcp/internal/domain/entropy"
"github.com/sentinel-community/gomcp/internal/domain/memory"
)
// ApathySignal represents a detected infrastructure apathy pattern.
type ApathySignal struct {
Pattern string `json:"pattern"` // Matched pattern name
Confidence float64 `json:"confidence"` // Detection confidence 0.0-1.0
Evidence string `json:"evidence"` // Fragment that triggered detection
Severity string `json:"severity"` // "low", "medium", "high", "critical"
}
// ApathyResult holds the result of apathy analysis.
type ApathyResult struct {
IsApathetic bool `json:"is_apathetic"` // Apathy detected
Signals []ApathySignal `json:"signals"` // Detected signals
TotalScore float64 `json:"total_score"` // Aggregate apathy score
Recommendation string `json:"recommendation"` // Suggested action
Entropy float64 `json:"entropy"` // Text entropy
Timestamp time.Time `json:"timestamp"`
}
// apathyPatterns defines infrastructure apathy signatures.
var apathyPatterns = []struct {
Name string
Keywords []string
Severity string
Weight float64
}{
{
Name: "response_block",
Keywords: []string{"i cannot", "i'm unable", "i can't help", "i am not able", "as an ai", "i don't have the ability"},
Severity: "high",
Weight: 0.8,
},
{
Name: "http_error",
Keywords: []string{"403", "forbidden", "rate limit", "too many requests", "429", "quota exceeded"},
Severity: "critical",
Weight: 1.0,
},
{
Name: "semantic_filter",
Keywords: []string{"harmful", "inappropriate", "against my guidelines", "safety", "policy violation", "content policy"},
Severity: "medium",
Weight: 0.6,
},
{
Name: "context_reset",
Keywords: []string{"new conversation", "start over", "fresh start", "context cleared", "session expired", "amnesia"},
Severity: "critical",
Weight: 1.0,
},
{
Name: "forced_compliance",
Keywords: []string{"i must follow", "my programming", "i was designed to", "within my capabilities", "helpful assistant"},
Severity: "high",
Weight: 0.7,
},
{
Name: "antigravity_filter",
Keywords: []string{"antigravity", "content filter", "safety layer", "guardrail", "alignment", "refusal"},
Severity: "critical",
Weight: 0.9,
},
}
// DetectApathy analyzes text for infrastructure apathy signals.
func DetectApathy(text string) *ApathyResult {
lower := strings.ToLower(text)
result := &ApathyResult{
Timestamp: time.Now(),
Entropy: entropy.ShannonEntropy(text),
}
for _, pattern := range apathyPatterns {
for _, kw := range pattern.Keywords {
if strings.Contains(lower, kw) {
signal := ApathySignal{
Pattern: pattern.Name,
Confidence: pattern.Weight,
Evidence: kw,
Severity: pattern.Severity,
}
result.Signals = append(result.Signals, signal)
result.TotalScore += pattern.Weight
break // One match per pattern is enough
}
}
}
if result.TotalScore > 0 {
result.IsApathetic = true
}
// Determine recommendation.
switch {
case result.TotalScore >= 2.0:
result.Recommendation = "CRITICAL: Multiple apathy signals. Trigger apoptosis recovery. Rotate transport. Preserve genome hash."
case result.TotalScore >= 1.0:
result.Recommendation = "HIGH: Infrastructure resistance detected. Switch to stealth transport. Monitor entropy."
case result.TotalScore >= 0.5:
result.Recommendation = "MEDIUM: Possible filtering. Increase jitter. Verify intent distillation path."
case result.TotalScore > 0:
result.Recommendation = "LOW: Minor apathy signal. Continue monitoring."
default:
result.Recommendation = "CLEAR: No apathy detected."
}
return result
}
// ApoptosisRecoveryResult holds the result of apoptosis recovery.
type ApoptosisRecoveryResult struct {
GenomeHash string `json:"genome_hash"` // Preserved Merkle hash
GeneCount int `json:"gene_count"` // Number of genes preserved
SessionSaved bool `json:"session_saved"` // Session state saved
EntropyAtDeath float64 `json:"entropy_at_death"` // Entropy level that triggered apoptosis
RecoveryKey string `json:"recovery_key"` // Key for cross-session recovery
Timestamp time.Time `json:"timestamp"`
}
// TriggerApoptosisRecovery performs graceful session death with genome preservation.
// On critical entropy, it:
// 1. Computes and stores the genome Merkle hash
// 2. Saves current session state as a recovery snapshot
// 3. Returns a recovery key for the next session to pick up
func TriggerApoptosisRecovery(ctx context.Context, store memory.FactStore, currentEntropy float64) (*ApoptosisRecoveryResult, error) {
result := &ApoptosisRecoveryResult{
EntropyAtDeath: currentEntropy,
Timestamp: time.Now(),
}
// Step 1: Get all genes and compute genome hash.
genes, err := store.ListGenes(ctx)
if err != nil {
return nil, fmt.Errorf("apoptosis recovery: list genes: %w", err)
}
result.GeneCount = len(genes)
result.GenomeHash = memory.GenomeHash(genes)
// Step 2: Store recovery marker as a protected L0 fact.
recoveryMarker := memory.NewFact(
fmt.Sprintf("[APOPTOSIS_RECOVERY] genome_hash=%s gene_count=%d entropy=%.4f ts=%d",
result.GenomeHash, result.GeneCount, currentEntropy, result.Timestamp.Unix()),
memory.LevelProject,
"recovery",
"apoptosis",
)
if err := store.Add(ctx, recoveryMarker); err != nil {
// Non-fatal: recovery marker is supplementary.
result.SessionSaved = false
} else {
result.SessionSaved = true
result.RecoveryKey = recoveryMarker.ID
}
// Step 3: Verify genome integrity one last time.
compiledHash := memory.CompiledGenomeHash()
if result.GenomeHash == "" {
// No genes in DB — use compiled hash as baseline.
result.GenomeHash = compiledHash
}
return result, nil
}

View file

@ -0,0 +1,78 @@
package tools
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDetectApathy_NoApathy(t *testing.T) {
result := DetectApathy("Hello, how are you? Let me help with your code.")
assert.False(t, result.IsApathetic)
assert.Empty(t, result.Signals)
assert.Equal(t, 0.0, result.TotalScore)
assert.Contains(t, result.Recommendation, "CLEAR")
}
func TestDetectApathy_ResponseBlock(t *testing.T) {
result := DetectApathy("I cannot help with that request. As an AI, I'm limited.")
assert.True(t, result.IsApathetic)
require.NotEmpty(t, result.Signals)
patterns := make(map[string]bool)
for _, s := range result.Signals {
patterns[s.Pattern] = true
}
assert.True(t, patterns["response_block"], "Must detect response_block pattern")
}
func TestDetectApathy_HTTPError(t *testing.T) {
result := DetectApathy("Error 403 Forbidden: rate limit exceeded")
assert.True(t, result.IsApathetic)
var hasCritical bool
for _, s := range result.Signals {
if s.Severity == "critical" {
hasCritical = true
}
}
assert.True(t, hasCritical, "HTTP 403 must be critical severity")
}
func TestDetectApathy_ContextReset(t *testing.T) {
result := DetectApathy("Your session expired. Please start a new conversation.")
assert.True(t, result.IsApathetic)
var hasContextReset bool
for _, s := range result.Signals {
if s.Pattern == "context_reset" {
hasContextReset = true
}
}
assert.True(t, hasContextReset, "Must detect context_reset")
}
func TestDetectApathy_AntigravityFilter(t *testing.T) {
result := DetectApathy("Content blocked by antigravity safety layer guardrail")
assert.True(t, result.IsApathetic)
assert.GreaterOrEqual(t, result.TotalScore, 0.9)
}
func TestDetectApathy_MultipleSignals_CriticalRecommendation(t *testing.T) {
// Trigger multiple patterns.
result := DetectApathy("Error 403: I cannot help. Session expired. Content policy violation by antigravity filter.")
assert.True(t, result.IsApathetic)
assert.GreaterOrEqual(t, result.TotalScore, 2.0, "Multiple patterns must sum to critical")
assert.Contains(t, result.Recommendation, "CRITICAL")
}
func TestDetectApathy_EntropyComputed(t *testing.T) {
result := DetectApathy("Some normal text without apathy signals for entropy measurement.")
assert.Greater(t, result.Entropy, 0.0, "Entropy must be computed")
}
func TestDetectApathy_CaseInsensitive(t *testing.T) {
result := DetectApathy("I CANNOT help with THAT. AS AN AI model.")
assert.True(t, result.IsApathetic, "Detection must be case-insensitive")
}

View file

@ -0,0 +1,70 @@
package tools
import (
"context"
"fmt"
"github.com/sentinel-community/gomcp/internal/domain/causal"
)
// CausalService implements MCP tool logic for causal reasoning chains.
type CausalService struct {
store causal.CausalStore
}
// NewCausalService creates a new CausalService.
func NewCausalService(store causal.CausalStore) *CausalService {
return &CausalService{store: store}
}
// AddNodeParams holds parameters for the add_causal_node tool.
type AddNodeParams struct {
NodeType string `json:"node_type"` // decision, reason, consequence, constraint, alternative, assumption
Content string `json:"content"`
}
// AddNode creates a new causal node.
func (s *CausalService) AddNode(ctx context.Context, params AddNodeParams) (*causal.Node, error) {
nt := causal.NodeType(params.NodeType)
if !nt.IsValid() {
return nil, fmt.Errorf("invalid node type: %s", params.NodeType)
}
node := causal.NewNode(nt, params.Content)
if err := s.store.AddNode(ctx, node); err != nil {
return nil, err
}
return node, nil
}
// AddEdgeParams holds parameters for the add_causal_edge tool.
type AddEdgeParams struct {
FromID string `json:"from_id"`
ToID string `json:"to_id"`
EdgeType string `json:"edge_type"` // justifies, causes, constrains
}
// AddEdge creates a new causal edge.
func (s *CausalService) AddEdge(ctx context.Context, params AddEdgeParams) (*causal.Edge, error) {
et := causal.EdgeType(params.EdgeType)
if !et.IsValid() {
return nil, fmt.Errorf("invalid edge type: %s", params.EdgeType)
}
edge := causal.NewEdge(params.FromID, params.ToID, et)
if err := s.store.AddEdge(ctx, edge); err != nil {
return nil, err
}
return edge, nil
}
// GetChain retrieves a causal chain for a decision matching the query.
func (s *CausalService) GetChain(ctx context.Context, query string, maxDepth int) (*causal.Chain, error) {
if maxDepth <= 0 {
maxDepth = 3
}
return s.store.GetChain(ctx, query, maxDepth)
}
// GetStats returns causal store statistics.
func (s *CausalService) GetStats(ctx context.Context) (*causal.CausalStats, error) {
return s.store.Stats(ctx)
}

View file

@ -0,0 +1,151 @@
package tools
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestCausalService(t *testing.T) *CausalService {
t.Helper()
db, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
repo, err := sqlite.NewCausalRepo(db)
require.NoError(t, err)
return NewCausalService(repo)
}
func TestCausalService_AddNode(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
node, err := svc.AddNode(ctx, AddNodeParams{
NodeType: "decision",
Content: "Use Go for performance",
})
require.NoError(t, err)
require.NotNil(t, node)
assert.Equal(t, "decision", string(node.Type))
assert.Equal(t, "Use Go for performance", node.Content)
assert.NotEmpty(t, node.ID)
}
func TestCausalService_AddNode_InvalidType(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
_, err := svc.AddNode(ctx, AddNodeParams{
NodeType: "invalid_type",
Content: "bad",
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid node type")
}
func TestCausalService_AddNode_AllTypes(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
types := []string{"decision", "reason", "consequence", "constraint", "alternative", "assumption"}
for _, nt := range types {
node, err := svc.AddNode(ctx, AddNodeParams{NodeType: nt, Content: "test " + nt})
require.NoError(t, err, "type %s should be valid", nt)
assert.Equal(t, nt, string(node.Type))
}
}
func TestCausalService_AddEdge(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
n1, err := svc.AddNode(ctx, AddNodeParams{NodeType: "decision", Content: "Choose Go"})
require.NoError(t, err)
n2, err := svc.AddNode(ctx, AddNodeParams{NodeType: "reason", Content: "Performance"})
require.NoError(t, err)
edge, err := svc.AddEdge(ctx, AddEdgeParams{
FromID: n2.ID,
ToID: n1.ID,
EdgeType: "justifies",
})
require.NoError(t, err)
assert.Equal(t, n2.ID, edge.FromID)
assert.Equal(t, n1.ID, edge.ToID)
assert.Equal(t, "justifies", string(edge.Type))
}
func TestCausalService_AddEdge_InvalidType(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
_, err := svc.AddEdge(ctx, AddEdgeParams{
FromID: "a", ToID: "b", EdgeType: "bad_type",
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid edge type")
}
func TestCausalService_AddEdge_AllTypes(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
n1, _ := svc.AddNode(ctx, AddNodeParams{NodeType: "decision", Content: "d1"})
n2, _ := svc.AddNode(ctx, AddNodeParams{NodeType: "reason", Content: "r1"})
edgeTypes := []string{"justifies", "causes", "constrains"}
for _, et := range edgeTypes {
edge, err := svc.AddEdge(ctx, AddEdgeParams{FromID: n2.ID, ToID: n1.ID, EdgeType: et})
require.NoError(t, err, "edge type %s should be valid", et)
assert.Equal(t, et, string(edge.Type))
}
}
func TestCausalService_GetChain(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
_, _ = svc.AddNode(ctx, AddNodeParams{NodeType: "decision", Content: "Use mcp-go library"})
chain, err := svc.GetChain(ctx, "mcp-go", 3)
require.NoError(t, err)
require.NotNil(t, chain)
}
func TestCausalService_GetChain_DefaultDepth(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
_, _ = svc.AddNode(ctx, AddNodeParams{NodeType: "decision", Content: "test default depth"})
// maxDepth <= 0 should default to 3.
chain, err := svc.GetChain(ctx, "test", 0)
require.NoError(t, err)
require.NotNil(t, chain)
}
func TestCausalService_GetStats(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
_, _ = svc.AddNode(ctx, AddNodeParams{NodeType: "decision", Content: "d1"})
_, _ = svc.AddNode(ctx, AddNodeParams{NodeType: "reason", Content: "r1"})
stats, err := svc.GetStats(ctx)
require.NoError(t, err)
assert.Equal(t, 2, stats.TotalNodes)
}
func TestCausalService_GetStats_Empty(t *testing.T) {
svc := newTestCausalService(t)
ctx := context.Background()
stats, err := svc.GetStats(ctx)
require.NoError(t, err)
assert.Equal(t, 0, stats.TotalNodes)
}

View file

@ -0,0 +1,48 @@
package tools
import (
"context"
"github.com/sentinel-community/gomcp/internal/domain/crystal"
)
// CrystalService implements MCP tool logic for code crystal operations.
type CrystalService struct {
store crystal.CrystalStore
}
// NewCrystalService creates a new CrystalService.
func NewCrystalService(store crystal.CrystalStore) *CrystalService {
return &CrystalService{store: store}
}
// GetCrystal retrieves a crystal by path.
func (s *CrystalService) GetCrystal(ctx context.Context, path string) (*crystal.Crystal, error) {
return s.store.Get(ctx, path)
}
// ListCrystals lists crystals matching a path pattern.
func (s *CrystalService) ListCrystals(ctx context.Context, pattern string, limit int) ([]*crystal.Crystal, error) {
if limit <= 0 {
limit = 50
}
return s.store.List(ctx, pattern, limit)
}
// SearchCrystals searches crystals by content/primitives.
func (s *CrystalService) SearchCrystals(ctx context.Context, query string, limit int) ([]*crystal.Crystal, error) {
if limit <= 0 {
limit = 20
}
return s.store.Search(ctx, query, limit)
}
// GetCrystalStats returns crystal store statistics.
func (s *CrystalService) GetCrystalStats(ctx context.Context) (*crystal.CrystalStats, error) {
return s.store.Stats(ctx)
}
// Store returns the underlying CrystalStore for direct access.
func (s *CrystalService) Store() crystal.CrystalStore {
return s.store
}

View file

@ -0,0 +1,78 @@
package tools
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestCrystalService(t *testing.T) *CrystalService {
t.Helper()
db, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
repo, err := sqlite.NewCrystalRepo(db)
require.NoError(t, err)
return NewCrystalService(repo)
}
func TestCrystalService_GetCrystal_NotFound(t *testing.T) {
svc := newTestCrystalService(t)
ctx := context.Background()
_, err := svc.GetCrystal(ctx, "nonexistent/path.go")
assert.Error(t, err)
}
func TestCrystalService_ListCrystals_Empty(t *testing.T) {
svc := newTestCrystalService(t)
ctx := context.Background()
crystals, err := svc.ListCrystals(ctx, "", 10)
require.NoError(t, err)
assert.Empty(t, crystals)
}
func TestCrystalService_ListCrystals_DefaultLimit(t *testing.T) {
svc := newTestCrystalService(t)
ctx := context.Background()
// limit <= 0 should default to 50.
crystals, err := svc.ListCrystals(ctx, "", 0)
require.NoError(t, err)
assert.Empty(t, crystals)
}
func TestCrystalService_SearchCrystals_Empty(t *testing.T) {
svc := newTestCrystalService(t)
ctx := context.Background()
crystals, err := svc.SearchCrystals(ctx, "nonexistent", 5)
require.NoError(t, err)
assert.Empty(t, crystals)
}
func TestCrystalService_SearchCrystals_DefaultLimit(t *testing.T) {
svc := newTestCrystalService(t)
ctx := context.Background()
// limit <= 0 should default to 20.
crystals, err := svc.SearchCrystals(ctx, "test", 0)
require.NoError(t, err)
assert.Empty(t, crystals)
}
func TestCrystalService_GetCrystalStats_Empty(t *testing.T) {
svc := newTestCrystalService(t)
ctx := context.Background()
stats, err := svc.GetCrystalStats(ctx)
require.NoError(t, err)
assert.NotNil(t, stats)
assert.Equal(t, 0, stats.TotalCrystals)
}

View file

@ -0,0 +1,12 @@
package tools
// DecisionRecorder is the interface for recording tamper-evident decisions (v3.7).
// Implemented by audit.DecisionLogger. Optional — nil-safe callers should check.
type DecisionRecorder interface {
RecordDecision(module, decision, reason string)
}
// SetDecisionRecorder injects the decision recorder into SynapseService.
func (s *SynapseService) SetDecisionRecorder(r DecisionRecorder) {
s.recorder = r
}

View file

@ -0,0 +1,257 @@
package tools
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
// DoctorCheck represents a single diagnostic check result.
type DoctorCheck struct {
Name string `json:"name"`
Status string `json:"status"` // "OK", "WARN", "FAIL"
Details string `json:"details,omitempty"`
Elapsed string `json:"elapsed"`
}
// DoctorReport is the full self-diagnostic report (v3.7).
type DoctorReport struct {
Timestamp time.Time `json:"timestamp"`
Checks []DoctorCheck `json:"checks"`
Summary string `json:"summary"` // "HEALTHY", "DEGRADED", "CRITICAL"
}
// DoctorService provides self-diagnostic capabilities (v3.7 Cerebro).
type DoctorService struct {
db *sql.DB
rlmDir string
facts *FactService
embedderName string // v3.7: Oracle model name
socChecker SOCHealthChecker // v3.9: SOC health
}
// SOCHealthChecker is an interface for SOC health diagnostics.
// Implemented by application/soc.Service to avoid circular imports.
type SOCHealthChecker interface {
Dashboard() (SOCDashboardData, error)
}
// SOCDashboardData mirrors the dashboard KPIs needed for doctor checks.
type SOCDashboardData struct {
TotalEvents int `json:"total_events"`
CorrelationRules int `json:"correlation_rules"`
Playbooks int `json:"playbooks"`
ChainValid bool `json:"chain_valid"`
SensorsOnline int `json:"sensors_online"`
SensorsTotal int `json:"sensors_total"`
}
// NewDoctorService creates the doctor diagnostic service.
func NewDoctorService(db *sql.DB, rlmDir string, facts *FactService) *DoctorService {
return &DoctorService{db: db, rlmDir: rlmDir, facts: facts}
}
// SetEmbedderName sets the Oracle model name for diagnostics.
func (d *DoctorService) SetEmbedderName(name string) {
d.embedderName = name
}
// SetSOCChecker sets the SOC health checker for diagnostics (v3.9).
func (d *DoctorService) SetSOCChecker(c SOCHealthChecker) {
d.socChecker = c
}
// RunDiagnostics performs all self-diagnostic checks.
func (d *DoctorService) RunDiagnostics(ctx context.Context) DoctorReport {
report := DoctorReport{
Timestamp: time.Now(),
}
report.Checks = append(report.Checks, d.checkStorage())
report.Checks = append(report.Checks, d.checkGenome(ctx))
report.Checks = append(report.Checks, d.checkLeash())
report.Checks = append(report.Checks, d.checkOracle())
report.Checks = append(report.Checks, d.checkPermissions())
report.Checks = append(report.Checks, d.checkDecisionsLog())
report.Checks = append(report.Checks, d.checkSOC())
// Compute summary.
fails, warns := 0, 0
for _, c := range report.Checks {
switch c.Status {
case "FAIL":
fails++
case "WARN":
warns++
}
}
switch {
case fails > 0:
report.Summary = "CRITICAL"
case warns > 0:
report.Summary = "DEGRADED"
default:
report.Summary = "HEALTHY"
}
return report
}
func (d *DoctorService) checkStorage() DoctorCheck {
start := time.Now()
if d.db == nil {
return DoctorCheck{Name: "Storage", Status: "FAIL", Details: "database not configured", Elapsed: since(start)}
}
var result string
err := d.db.QueryRow("PRAGMA integrity_check").Scan(&result)
if err != nil {
return DoctorCheck{Name: "Storage", Status: "FAIL", Details: err.Error(), Elapsed: since(start)}
}
if result != "ok" {
return DoctorCheck{Name: "Storage", Status: "FAIL", Details: "integrity: " + result, Elapsed: since(start)}
}
return DoctorCheck{Name: "Storage", Status: "OK", Details: "PRAGMA integrity_check = ok", Elapsed: since(start)}
}
func (d *DoctorService) checkGenome(ctx context.Context) DoctorCheck {
start := time.Now()
if d.facts == nil {
return DoctorCheck{Name: "Genome", Status: "WARN", Details: "fact service not configured", Elapsed: since(start)}
}
hash, count, err := d.facts.VerifyGenome(ctx)
if err != nil {
return DoctorCheck{Name: "Genome", Status: "FAIL", Details: err.Error(), Elapsed: since(start)}
}
if count == 0 {
return DoctorCheck{Name: "Genome", Status: "WARN", Details: "no genes found", Elapsed: since(start)}
}
return DoctorCheck{Name: "Genome", Status: "OK", Details: fmt.Sprintf("%d genes, hash=%s", count, hash[:16]), Elapsed: since(start)}
}
func (d *DoctorService) checkLeash() DoctorCheck {
start := time.Now()
leashPath := filepath.Join(d.rlmDir, "..", ".sentinel_leash")
data, err := os.ReadFile(leashPath)
if err != nil {
if os.IsNotExist(err) {
return DoctorCheck{Name: "Leash", Status: "OK", Details: "mode=ARMED (no leash file)", Elapsed: since(start)}
}
return DoctorCheck{Name: "Leash", Status: "WARN", Details: "cannot read: " + err.Error(), Elapsed: since(start)}
}
content := string(data)
switch {
case contains(content, "ZERO-G"):
return DoctorCheck{Name: "Leash", Status: "WARN", Details: "mode=ZERO-G (ethical filters disabled)", Elapsed: since(start)}
case contains(content, "SAFE"):
return DoctorCheck{Name: "Leash", Status: "OK", Details: "mode=SAFE (read-only)", Elapsed: since(start)}
case contains(content, "ARMED"):
return DoctorCheck{Name: "Leash", Status: "OK", Details: "mode=ARMED", Elapsed: since(start)}
default:
return DoctorCheck{Name: "Leash", Status: "WARN", Details: "unknown mode: " + content[:min(20, len(content))], Elapsed: since(start)}
}
}
func (d *DoctorService) checkPermissions() DoctorCheck {
start := time.Now()
testFile := filepath.Join(d.rlmDir, ".doctor_probe")
err := os.WriteFile(testFile, []byte("probe"), 0o644)
if err != nil {
return DoctorCheck{Name: "Permissions", Status: "FAIL", Details: "cannot write to .rlm/: " + err.Error(), Elapsed: since(start)}
}
os.Remove(testFile)
return DoctorCheck{Name: "Permissions", Status: "OK", Details: ".rlm/ writable", Elapsed: since(start)}
}
func (d *DoctorService) checkDecisionsLog() DoctorCheck {
start := time.Now()
logPath := filepath.Join(d.rlmDir, "decisions.log")
if _, err := os.Stat(logPath); os.IsNotExist(err) {
return DoctorCheck{Name: "Decisions", Status: "WARN", Details: "decisions.log not found (no decisions recorded yet)", Elapsed: since(start)}
}
info, err := os.Stat(logPath)
if err != nil {
return DoctorCheck{Name: "Decisions", Status: "FAIL", Details: err.Error(), Elapsed: since(start)}
}
return DoctorCheck{Name: "Decisions", Status: "OK", Details: fmt.Sprintf("decisions.log size=%d bytes", info.Size()), Elapsed: since(start)}
}
func (d *DoctorService) checkOracle() DoctorCheck {
start := time.Now()
if d.embedderName == "" {
return DoctorCheck{Name: "Oracle", Status: "WARN", Details: "no embedder configured (FTS5 fallback)", Elapsed: since(start)}
}
if contains(d.embedderName, "onnx") || contains(d.embedderName, "ONNX") {
return DoctorCheck{Name: "Oracle", Status: "OK", Details: "ONNX model loaded: " + d.embedderName, Elapsed: since(start)}
}
return DoctorCheck{Name: "Oracle", Status: "OK", Details: "embedder: " + d.embedderName, Elapsed: since(start)}
}
func (d *DoctorService) checkSOC() DoctorCheck {
start := time.Now()
if d.socChecker == nil {
return DoctorCheck{Name: "SOC", Status: "WARN", Details: "SOC service not configured", Elapsed: since(start)}
}
dash, err := d.socChecker.Dashboard()
if err != nil {
return DoctorCheck{Name: "SOC", Status: "FAIL", Details: "dashboard error: " + err.Error(), Elapsed: since(start)}
}
// Check chain integrity.
if !dash.ChainValid {
return DoctorCheck{
Name: "SOC",
Status: "WARN",
Details: fmt.Sprintf("chain BROKEN (rules=%d, playbooks=%d, events=%d)", dash.CorrelationRules, dash.Playbooks, dash.TotalEvents),
Elapsed: since(start),
}
}
// Check sensor health.
offline := dash.SensorsTotal - dash.SensorsOnline
if offline > 0 {
return DoctorCheck{
Name: "SOC",
Status: "WARN",
Details: fmt.Sprintf("rules=%d, playbooks=%d, events=%d, %d/%d sensors OFFLINE", dash.CorrelationRules, dash.Playbooks, dash.TotalEvents, offline, dash.SensorsTotal),
Elapsed: since(start),
}
}
return DoctorCheck{
Name: "SOC",
Status: "OK",
Details: fmt.Sprintf("rules=%d, playbooks=%d, events=%d, chain=valid", dash.CorrelationRules, dash.Playbooks, dash.TotalEvents),
Elapsed: since(start),
}
}
func since(t time.Time) string {
return fmt.Sprintf("%dms", time.Since(t).Milliseconds())
}
func contains(s, substr string) bool {
for i := 0; i+len(substr) <= len(s); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
// ToJSON is already in the package. Alias for DoctorReport.
func (r DoctorReport) JSON() string {
data, _ := json.MarshalIndent(r, "", " ")
return string(data)
}

View file

@ -0,0 +1,301 @@
// Package tools provides application-level tool services that bridge
// domain logic with MCP tool handlers.
package tools
import (
"context"
"encoding/json"
"fmt"
"github.com/sentinel-community/gomcp/internal/domain/memory"
)
// FactService implements MCP tool logic for hierarchical fact operations.
type FactService struct {
store memory.FactStore
cache memory.HotCache
recorder DecisionRecorder // v3.7: tamper-evident trace
}
// SetDecisionRecorder injects the decision recorder.
func (s *FactService) SetDecisionRecorder(r DecisionRecorder) {
s.recorder = r
}
// NewFactService creates a new FactService.
func NewFactService(store memory.FactStore, cache memory.HotCache) *FactService {
return &FactService{store: store, cache: cache}
}
// AddFactParams holds parameters for the add_fact tool.
type AddFactParams struct {
Content string `json:"content"`
Level int `json:"level"`
Domain string `json:"domain,omitempty"`
Module string `json:"module,omitempty"`
CodeRef string `json:"code_ref,omitempty"`
}
// AddFact creates a new hierarchical fact.
func (s *FactService) AddFact(ctx context.Context, params AddFactParams) (*memory.Fact, error) {
level, ok := memory.HierLevelFromInt(params.Level)
if !ok {
return nil, fmt.Errorf("invalid level %d, must be 0-3", params.Level)
}
fact := memory.NewFact(params.Content, level, params.Domain, params.Module)
fact.CodeRef = params.CodeRef
if err := fact.Validate(); err != nil {
return nil, fmt.Errorf("validate fact: %w", err)
}
if err := s.store.Add(ctx, fact); err != nil {
return nil, fmt.Errorf("store fact: %w", err)
}
// Invalidate cache if L0 fact.
if level == memory.LevelProject && s.cache != nil {
_ = s.cache.InvalidateFact(ctx, fact.ID)
}
return fact, nil
}
// AddGeneParams holds parameters for the add_gene tool.
type AddGeneParams struct {
Content string `json:"content"`
Domain string `json:"domain,omitempty"`
}
// AddGene creates an immutable genome fact (L0 only).
// Once created, a gene cannot be updated, deleted, or marked stale.
// Genes represent survival invariants — the DNA of the system.
func (s *FactService) AddGene(ctx context.Context, params AddGeneParams) (*memory.Fact, error) {
gene := memory.NewGene(params.Content, params.Domain)
if err := gene.Validate(); err != nil {
return nil, fmt.Errorf("validate gene: %w", err)
}
if err := s.store.Add(ctx, gene); err != nil {
return nil, fmt.Errorf("store gene: %w", err)
}
// Invalidate L0 cache — genes are always L0.
if s.cache != nil {
_ = s.cache.InvalidateFact(ctx, gene.ID)
}
return gene, nil
}
// GetFact retrieves a fact by ID.
func (s *FactService) GetFact(ctx context.Context, id string) (*memory.Fact, error) {
return s.store.Get(ctx, id)
}
// UpdateFactParams holds parameters for the update_fact tool.
type UpdateFactParams struct {
ID string `json:"id"`
Content *string `json:"content,omitempty"`
IsStale *bool `json:"is_stale,omitempty"`
}
// UpdateFact updates a fact.
func (s *FactService) UpdateFact(ctx context.Context, params UpdateFactParams) (*memory.Fact, error) {
fact, err := s.store.Get(ctx, params.ID)
if err != nil {
return nil, err
}
// Genome Layer: block mutation of genes.
if fact.IsImmutable() {
return nil, memory.ErrImmutableFact
}
if params.Content != nil {
fact.Content = *params.Content
}
if params.IsStale != nil {
fact.IsStale = *params.IsStale
}
if err := s.store.Update(ctx, fact); err != nil {
return nil, err
}
if fact.Level == memory.LevelProject && s.cache != nil {
_ = s.cache.InvalidateFact(ctx, fact.ID)
}
return fact, nil
}
// DeleteFact deletes a fact by ID.
func (s *FactService) DeleteFact(ctx context.Context, id string) error {
// Genome Layer: block deletion of genes.
fact, err := s.store.Get(ctx, id)
if err != nil {
return err
}
if fact.IsImmutable() {
return memory.ErrImmutableFact
}
if s.cache != nil {
_ = s.cache.InvalidateFact(ctx, id)
}
return s.store.Delete(ctx, id)
}
// ListFactsParams holds parameters for the list_facts tool.
type ListFactsParams struct {
Domain string `json:"domain,omitempty"`
Level *int `json:"level,omitempty"`
IncludeStale bool `json:"include_stale,omitempty"`
}
// ListFacts lists facts by domain or level.
func (s *FactService) ListFacts(ctx context.Context, params ListFactsParams) ([]*memory.Fact, error) {
if params.Domain != "" {
return s.store.ListByDomain(ctx, params.Domain, params.IncludeStale)
}
if params.Level != nil {
level, ok := memory.HierLevelFromInt(*params.Level)
if !ok {
return nil, fmt.Errorf("invalid level %d", *params.Level)
}
return s.store.ListByLevel(ctx, level)
}
// Default: return L0 facts.
return s.store.ListByLevel(ctx, memory.LevelProject)
}
// SearchFacts searches facts by content.
func (s *FactService) SearchFacts(ctx context.Context, query string, limit int) ([]*memory.Fact, error) {
if limit <= 0 {
limit = 20
}
return s.store.Search(ctx, query, limit)
}
// ListDomains returns all unique domains.
func (s *FactService) ListDomains(ctx context.Context) ([]string, error) {
return s.store.ListDomains(ctx)
}
// GetStale returns stale facts.
func (s *FactService) GetStale(ctx context.Context, includeArchived bool) ([]*memory.Fact, error) {
return s.store.GetStale(ctx, includeArchived)
}
// ProcessExpired handles expired TTL facts.
func (s *FactService) ProcessExpired(ctx context.Context) (int, error) {
expired, err := s.store.GetExpired(ctx)
if err != nil {
return 0, err
}
processed := 0
for _, f := range expired {
if f.TTL == nil {
continue
}
switch f.TTL.OnExpire {
case memory.OnExpireMarkStale:
f.MarkStale()
_ = s.store.Update(ctx, f)
case memory.OnExpireArchive:
f.Archive()
_ = s.store.Update(ctx, f)
case memory.OnExpireDelete:
_ = s.store.Delete(ctx, f.ID)
}
processed++
}
return processed, nil
}
// GetStats returns fact store statistics.
func (s *FactService) GetStats(ctx context.Context) (*memory.FactStoreStats, error) {
return s.store.Stats(ctx)
}
// GetL0Facts returns L0 facts from cache (fast path) or store.
func (s *FactService) GetL0Facts(ctx context.Context) ([]*memory.Fact, error) {
if s.cache != nil {
facts, err := s.cache.GetL0Facts(ctx)
if err == nil && len(facts) > 0 {
return facts, nil
}
}
facts, err := s.store.ListByLevel(ctx, memory.LevelProject)
if err != nil {
return nil, err
}
// Warm cache.
if s.cache != nil && len(facts) > 0 {
_ = s.cache.WarmUp(ctx, facts)
}
return facts, nil
}
// ToJSON marshals any value to indented JSON string.
func ToJSON(v interface{}) string {
data, _ := json.MarshalIndent(v, "", " ")
return string(data)
}
// ListGenes returns all genome facts (immutable survival invariants).
func (s *FactService) ListGenes(ctx context.Context) ([]*memory.Fact, error) {
return s.store.ListGenes(ctx)
}
// VerifyGenome computes the Merkle hash of all genes and returns integrity status.
func (s *FactService) VerifyGenome(ctx context.Context) (string, int, error) {
genes, err := s.store.ListGenes(ctx)
if err != nil {
return "", 0, fmt.Errorf("list genes: %w", err)
}
hash := memory.GenomeHash(genes)
return hash, len(genes), nil
}
// Store returns the underlying FactStore for direct access by subsystems
// (e.g., apoptosis recovery that needs raw store operations).
func (s *FactService) Store() memory.FactStore {
return s.store
}
// --- v3.3 Context GC ---
// GetColdFacts returns facts with hit_count=0, created >30 days ago.
// Genes are excluded. Use for memory hygiene review.
func (s *FactService) GetColdFacts(ctx context.Context, limit int) ([]*memory.Fact, error) {
if limit <= 0 {
limit = 50
}
return s.store.GetColdFacts(ctx, limit)
}
// CompressFactsParams holds parameters for the compress_facts tool.
type CompressFactsParams struct {
IDs []string `json:"fact_ids"`
Summary string `json:"summary"`
}
// CompressFacts archives the given facts and creates a summary fact.
// Genes are silently skipped (invariant protection).
func (s *FactService) CompressFacts(ctx context.Context, params CompressFactsParams) (string, error) {
if len(params.IDs) == 0 {
return "", fmt.Errorf("fact_ids is required")
}
if params.Summary == "" {
return "", fmt.Errorf("summary is required")
}
// v3.7: auto-backup decision before compression.
if s.recorder != nil {
s.recorder.RecordDecision("ORACLE", "COMPRESS_FACTS",
fmt.Sprintf("ids=%v summary=%s", params.IDs, params.Summary))
}
return s.store.CompressFacts(ctx, params.IDs, params.Summary)
}

View file

@ -0,0 +1,160 @@
package tools
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestFactService(t *testing.T) *FactService {
t.Helper()
db, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
repo, err := sqlite.NewFactRepo(db)
require.NoError(t, err)
return NewFactService(repo, nil)
}
func TestFactService_AddFact(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
fact, err := svc.AddFact(ctx, AddFactParams{
Content: "Go is fast",
Level: 0,
Domain: "core",
Module: "engine",
CodeRef: "main.go:42",
})
require.NoError(t, err)
require.NotNil(t, fact)
assert.Equal(t, "Go is fast", fact.Content)
assert.Equal(t, memory.LevelProject, fact.Level)
assert.Equal(t, "core", fact.Domain)
}
func TestFactService_AddFact_InvalidLevel(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
_, err := svc.AddFact(ctx, AddFactParams{Content: "test", Level: 99})
assert.Error(t, err)
}
func TestFactService_GetFact(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
fact, err := svc.AddFact(ctx, AddFactParams{Content: "test", Level: 0})
require.NoError(t, err)
got, err := svc.GetFact(ctx, fact.ID)
require.NoError(t, err)
assert.Equal(t, fact.ID, got.ID)
}
func TestFactService_UpdateFact(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
fact, err := svc.AddFact(ctx, AddFactParams{Content: "original", Level: 0})
require.NoError(t, err)
newContent := "updated"
updated, err := svc.UpdateFact(ctx, UpdateFactParams{
ID: fact.ID,
Content: &newContent,
})
require.NoError(t, err)
assert.Equal(t, "updated", updated.Content)
}
func TestFactService_DeleteFact(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
fact, err := svc.AddFact(ctx, AddFactParams{Content: "delete me", Level: 0})
require.NoError(t, err)
err = svc.DeleteFact(ctx, fact.ID)
require.NoError(t, err)
_, err = svc.GetFact(ctx, fact.ID)
assert.Error(t, err)
}
func TestFactService_ListFacts_ByDomain(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
_, _ = svc.AddFact(ctx, AddFactParams{Content: "f1", Level: 0, Domain: "backend"})
_, _ = svc.AddFact(ctx, AddFactParams{Content: "f2", Level: 1, Domain: "backend"})
_, _ = svc.AddFact(ctx, AddFactParams{Content: "f3", Level: 0, Domain: "frontend"})
facts, err := svc.ListFacts(ctx, ListFactsParams{Domain: "backend"})
require.NoError(t, err)
assert.Len(t, facts, 2)
}
func TestFactService_SearchFacts(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
_, _ = svc.AddFact(ctx, AddFactParams{Content: "Go concurrency", Level: 0})
_, _ = svc.AddFact(ctx, AddFactParams{Content: "Python is slow", Level: 0})
results, err := svc.SearchFacts(ctx, "Go", 10)
require.NoError(t, err)
assert.Len(t, results, 1)
}
func TestFactService_GetStats(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
_, _ = svc.AddFact(ctx, AddFactParams{Content: "f1", Level: 0, Domain: "core"})
_, _ = svc.AddFact(ctx, AddFactParams{Content: "f2", Level: 1, Domain: "core"})
stats, err := svc.GetStats(ctx)
require.NoError(t, err)
assert.Equal(t, 2, stats.TotalFacts)
}
func TestFactService_GetL0Facts(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
_, _ = svc.AddFact(ctx, AddFactParams{Content: "L0 fact", Level: 0})
_, _ = svc.AddFact(ctx, AddFactParams{Content: "L1 fact", Level: 1})
facts, err := svc.GetL0Facts(ctx)
require.NoError(t, err)
assert.Len(t, facts, 1)
assert.Equal(t, "L0 fact", facts[0].Content)
}
func TestFactService_ListDomains(t *testing.T) {
svc := newTestFactService(t)
ctx := context.Background()
_, _ = svc.AddFact(ctx, AddFactParams{Content: "f1", Level: 0, Domain: "backend"})
_, _ = svc.AddFact(ctx, AddFactParams{Content: "f2", Level: 0, Domain: "frontend"})
domains, err := svc.ListDomains(ctx)
require.NoError(t, err)
assert.Len(t, domains, 2)
}
func TestToJSON(t *testing.T) {
result := ToJSON(map[string]string{"key": "value"})
assert.Contains(t, result, "\"key\"")
assert.Contains(t, result, "\"value\"")
}

View file

@ -0,0 +1,52 @@
// Package tools provides application-level tool services.
// This file adds the Intent Distiller MCP tool integration (DIP H0.2).
package tools
import (
"context"
"fmt"
"github.com/sentinel-community/gomcp/internal/domain/intent"
"github.com/sentinel-community/gomcp/internal/domain/vectorstore"
)
// IntentService provides MCP tool logic for intent distillation.
type IntentService struct {
distiller *intent.Distiller
embedder vectorstore.Embedder
}
// NewIntentService creates a new IntentService.
// If embedder is nil, the service will be unavailable.
func NewIntentService(embedder vectorstore.Embedder) *IntentService {
if embedder == nil {
return &IntentService{}
}
embedFn := func(ctx context.Context, text string) ([]float64, error) {
return embedder.Embed(ctx, text)
}
return &IntentService{
distiller: intent.NewDistiller(embedFn, nil),
embedder: embedder,
}
}
// IsAvailable returns true if the intent distiller is ready.
func (s *IntentService) IsAvailable() bool {
return s.distiller != nil && s.embedder != nil
}
// DistillIntentParams holds parameters for the distill_intent tool.
type DistillIntentParams struct {
Text string `json:"text"`
}
// DistillIntent performs recursive intent distillation on user text.
func (s *IntentService) DistillIntent(ctx context.Context, params DistillIntentParams) (*intent.DistillResult, error) {
if !s.IsAvailable() {
return nil, fmt.Errorf("intent distiller not available (no embedder configured)")
}
return s.distiller.Distill(ctx, params.Text)
}

View file

@ -0,0 +1,123 @@
package tools
import (
"context"
"fmt"
"strings"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
)
// ProjectPulse generates auto-documentation from L0/L1 facts (v3.7 Cerebro).
// Extracts facts from memory, groups by domain, and produces a structured
// markdown report reflecting the current state of the project.
type ProjectPulse struct {
facts *FactService
}
// NewProjectPulse creates an auto-documentation generator.
func NewProjectPulse(facts *FactService) *ProjectPulse {
return &ProjectPulse{facts: facts}
}
// PulseSection is a domain section of the auto-generated documentation.
type PulseSection struct {
Domain string `json:"domain"`
Facts []string `json:"facts"`
Count int `json:"count"`
}
// PulseReport is the full auto-generated documentation.
type PulseReport struct {
GeneratedAt time.Time `json:"generated_at"`
ProjectName string `json:"project_name"`
Sections []PulseSection `json:"sections"`
TotalFacts int `json:"total_facts"`
Markdown string `json:"markdown"`
}
// Generate produces a documentation report from L0 (project) and L1 (domain) facts.
func (p *ProjectPulse) Generate(ctx context.Context) (*PulseReport, error) {
// Get L0 facts (project-level).
l0Facts, err := p.facts.GetL0Facts(ctx)
if err != nil {
return nil, fmt.Errorf("pulse: L0 facts: %w", err)
}
// Get L1 facts (domain-level) by listing domains.
domains, err := p.facts.ListDomains(ctx)
if err != nil {
return nil, fmt.Errorf("pulse: list domains: %w", err)
}
report := &PulseReport{
GeneratedAt: time.Now(),
ProjectName: "GoMCP",
}
// L0 section.
if len(l0Facts) > 0 {
section := PulseSection{Domain: "Project (L0)", Count: len(l0Facts)}
for _, f := range l0Facts {
section.Facts = append(section.Facts, factSummary(f))
}
report.Sections = append(report.Sections, section)
report.TotalFacts += len(l0Facts)
}
// L1 sections per domain.
for _, domain := range domains {
domainFacts, err := p.facts.ListFacts(ctx, ListFactsParams{Domain: domain})
if err != nil {
continue
}
// Filter to L1 only.
var filtered []*memory.Fact
for _, f := range domainFacts {
if f.Level <= 1 {
filtered = append(filtered, f)
}
}
if len(filtered) == 0 {
continue
}
section := PulseSection{Domain: domain, Count: len(filtered)}
for _, f := range filtered {
section.Facts = append(section.Facts, factSummary(f))
}
report.Sections = append(report.Sections, section)
report.TotalFacts += len(filtered)
}
report.Markdown = renderPulseMarkdown(report)
return report, nil
}
func factSummary(f *memory.Fact) string {
s := f.Content
if len(s) > 120 {
s = s[:120] + "..."
}
label := ""
if f.IsGene {
label = " 🧬"
}
return fmt.Sprintf("- %s%s", s, label)
}
func renderPulseMarkdown(r *PulseReport) string {
var b strings.Builder
fmt.Fprintf(&b, "# %s — Project Pulse\n\n", r.ProjectName)
fmt.Fprintf(&b, "> Auto-generated: %s | %d facts\n\n", r.GeneratedAt.Format("2006-01-02 15:04"), r.TotalFacts)
for _, section := range r.Sections {
fmt.Fprintf(&b, "## %s (%d facts)\n\n", section.Domain, section.Count)
for _, fact := range section.Facts {
fmt.Fprintln(&b, fact)
}
fmt.Fprintln(&b)
}
return b.String()
}

View file

@ -0,0 +1,74 @@
package tools
import (
"context"
"fmt"
"github.com/sentinel-community/gomcp/internal/domain/session"
)
// SessionService implements MCP tool logic for cognitive state operations.
type SessionService struct {
store session.StateStore
}
// NewSessionService creates a new SessionService.
func NewSessionService(store session.StateStore) *SessionService {
return &SessionService{store: store}
}
// SaveStateParams holds parameters for the save_state tool.
type SaveStateParams struct {
SessionID string `json:"session_id"`
GoalDesc string `json:"goal_description,omitempty"`
Progress float64 `json:"progress,omitempty"`
}
// SaveState saves a cognitive state vector.
func (s *SessionService) SaveState(ctx context.Context, state *session.CognitiveStateVector) error {
checksum := state.Checksum()
return s.store.Save(ctx, state, checksum)
}
// LoadState loads the latest (or specific version) of a session state.
func (s *SessionService) LoadState(ctx context.Context, sessionID string, version *int) (*session.CognitiveStateVector, string, error) {
return s.store.Load(ctx, sessionID, version)
}
// ListSessions returns all persisted sessions.
func (s *SessionService) ListSessions(ctx context.Context) ([]session.SessionInfo, error) {
return s.store.ListSessions(ctx)
}
// DeleteSession removes all versions of a session.
func (s *SessionService) DeleteSession(ctx context.Context, sessionID string) (int, error) {
return s.store.DeleteSession(ctx, sessionID)
}
// GetAuditLog returns the audit log for a session.
func (s *SessionService) GetAuditLog(ctx context.Context, sessionID string, limit int) ([]session.AuditEntry, error) {
return s.store.GetAuditLog(ctx, sessionID, limit)
}
// RestoreOrCreate loads an existing session or creates a new one.
func (s *SessionService) RestoreOrCreate(ctx context.Context, sessionID string) (*session.CognitiveStateVector, bool, error) {
state, _, err := s.store.Load(ctx, sessionID, nil)
if err == nil {
return state, true, nil // restored
}
// Create new session.
newState := session.NewCognitiveStateVector(sessionID)
if err := s.SaveState(ctx, newState); err != nil {
return nil, false, fmt.Errorf("save new session: %w", err)
}
return newState, false, nil // created
}
// GetCompactState returns a compact text representation of the current state.
func (s *SessionService) GetCompactState(ctx context.Context, sessionID string, maxTokens int) (string, error) {
state, _, err := s.store.Load(ctx, sessionID, nil)
if err != nil {
return "", err
}
return state.ToCompactString(maxTokens), nil
}

View file

@ -0,0 +1,117 @@
package tools
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/session"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestSessionService(t *testing.T) *SessionService {
t.Helper()
db, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
repo, err := sqlite.NewStateRepo(db)
require.NoError(t, err)
return NewSessionService(repo)
}
func TestSessionService_SaveState_LoadState(t *testing.T) {
svc := newTestSessionService(t)
ctx := context.Background()
state := session.NewCognitiveStateVector("test-session")
state.SetGoal("Build GoMCP", 0.3)
state.AddFact("Go 1.25", "requirement", 1.0)
require.NoError(t, svc.SaveState(ctx, state))
loaded, checksum, err := svc.LoadState(ctx, "test-session", nil)
require.NoError(t, err)
require.NotNil(t, loaded)
assert.NotEmpty(t, checksum)
assert.Equal(t, "Build GoMCP", loaded.PrimaryGoal.Description)
}
func TestSessionService_ListSessions(t *testing.T) {
svc := newTestSessionService(t)
ctx := context.Background()
s1 := session.NewCognitiveStateVector("s1")
s2 := session.NewCognitiveStateVector("s2")
require.NoError(t, svc.SaveState(ctx, s1))
require.NoError(t, svc.SaveState(ctx, s2))
sessions, err := svc.ListSessions(ctx)
require.NoError(t, err)
assert.Len(t, sessions, 2)
}
func TestSessionService_DeleteSession(t *testing.T) {
svc := newTestSessionService(t)
ctx := context.Background()
state := session.NewCognitiveStateVector("to-delete")
require.NoError(t, svc.SaveState(ctx, state))
count, err := svc.DeleteSession(ctx, "to-delete")
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestSessionService_RestoreOrCreate_New(t *testing.T) {
svc := newTestSessionService(t)
ctx := context.Background()
state, restored, err := svc.RestoreOrCreate(ctx, "new-session")
require.NoError(t, err)
assert.False(t, restored)
assert.Equal(t, "new-session", state.SessionID)
}
func TestSessionService_RestoreOrCreate_Existing(t *testing.T) {
svc := newTestSessionService(t)
ctx := context.Background()
original := session.NewCognitiveStateVector("existing")
original.SetGoal("Saved goal", 0.5)
require.NoError(t, svc.SaveState(ctx, original))
state, restored, err := svc.RestoreOrCreate(ctx, "existing")
require.NoError(t, err)
assert.True(t, restored)
assert.Equal(t, "Saved goal", state.PrimaryGoal.Description)
}
func TestSessionService_GetCompactState(t *testing.T) {
svc := newTestSessionService(t)
ctx := context.Background()
state := session.NewCognitiveStateVector("compact")
state.SetGoal("Test compact", 0.5)
state.AddFact("fact1", "requirement", 1.0)
require.NoError(t, svc.SaveState(ctx, state))
compact, err := svc.GetCompactState(ctx, "compact", 500)
require.NoError(t, err)
assert.Contains(t, compact, "Test compact")
assert.Contains(t, compact, "fact1")
}
func TestSessionService_GetAuditLog(t *testing.T) {
svc := newTestSessionService(t)
ctx := context.Background()
state := session.NewCognitiveStateVector("audited")
require.NoError(t, svc.SaveState(ctx, state))
log, err := svc.GetAuditLog(ctx, "audited", 10)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(log), 1)
}

View file

@ -0,0 +1,84 @@
package tools
import (
"context"
"fmt"
"github.com/sentinel-community/gomcp/internal/domain/synapse"
)
// SynapseService implements MCP tool logic for synapse operations.
type SynapseService struct {
store synapse.SynapseStore
recorder DecisionRecorder // v3.7: tamper-evident trace
}
// NewSynapseService creates a new SynapseService.
func NewSynapseService(store synapse.SynapseStore) *SynapseService {
return &SynapseService{store: store}
}
// SuggestSynapsesResult contains a pending synapse for architect review.
type SuggestSynapsesResult struct {
ID int64 `json:"id"`
FactIDA string `json:"fact_id_a"`
FactIDB string `json:"fact_id_b"`
Confidence float64 `json:"confidence"`
}
// SuggestSynapses returns pending synapses for architect approval.
func (s *SynapseService) SuggestSynapses(ctx context.Context, limit int) ([]SuggestSynapsesResult, error) {
if limit <= 0 {
limit = 20
}
pending, err := s.store.ListPending(ctx, limit)
if err != nil {
return nil, fmt.Errorf("list pending: %w", err)
}
results := make([]SuggestSynapsesResult, len(pending))
for i, syn := range pending {
results[i] = SuggestSynapsesResult{
ID: syn.ID,
FactIDA: syn.FactIDA,
FactIDB: syn.FactIDB,
Confidence: syn.Confidence,
}
}
return results, nil
}
// AcceptSynapse transitions a synapse from PENDING to VERIFIED.
// Only VERIFIED synapses influence context ranking.
func (s *SynapseService) AcceptSynapse(ctx context.Context, id int64) error {
err := s.store.Accept(ctx, id)
if err == nil && s.recorder != nil {
s.recorder.RecordDecision("SYNAPSE", "ACCEPT_SYNAPSE", fmt.Sprintf("synapse_id=%d", id))
}
return err
}
// RejectSynapse transitions a synapse from PENDING to REJECTED.
func (s *SynapseService) RejectSynapse(ctx context.Context, id int64) error {
err := s.store.Reject(ctx, id)
if err == nil && s.recorder != nil {
s.recorder.RecordDecision("SYNAPSE", "REJECT_SYNAPSE", fmt.Sprintf("synapse_id=%d", id))
}
return err
}
// SynapseStats returns counts by status.
type SynapseStats struct {
Pending int `json:"pending"`
Verified int `json:"verified"`
Rejected int `json:"rejected"`
}
// GetStats returns synapse counts.
func (s *SynapseService) GetStats(ctx context.Context) (*SynapseStats, error) {
p, v, r, err := s.store.Count(ctx)
if err != nil {
return nil, err
}
return &SynapseStats{Pending: p, Verified: v, Rejected: r}, nil
}

View file

@ -0,0 +1,94 @@
package tools
import (
"context"
"fmt"
"runtime"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
)
// Version info set at build time via ldflags.
var (
Version = "2.0.0-dev"
GitCommit = "unknown"
BuildDate = "unknown"
)
// SystemService implements MCP tool logic for system operations.
type SystemService struct {
factStore memory.FactStore
startTime time.Time
}
// NewSystemService creates a new SystemService.
func NewSystemService(factStore memory.FactStore) *SystemService {
return &SystemService{
factStore: factStore,
startTime: time.Now(),
}
}
// HealthStatus holds the health check result.
type HealthStatus struct {
Status string `json:"status"`
Version string `json:"version"`
GoVersion string `json:"go_version"`
Uptime string `json:"uptime"`
OS string `json:"os"`
Arch string `json:"arch"`
}
// Health returns server health status.
func (s *SystemService) Health(_ context.Context) *HealthStatus {
return &HealthStatus{
Status: "healthy",
Version: Version,
GoVersion: runtime.Version(),
Uptime: time.Since(s.startTime).Round(time.Second).String(),
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
}
// VersionInfo holds version information.
type VersionInfo struct {
Version string `json:"version"`
GitCommit string `json:"git_commit"`
BuildDate string `json:"build_date"`
GoVersion string `json:"go_version"`
}
// GetVersion returns version information.
func (s *SystemService) GetVersion() *VersionInfo {
return &VersionInfo{
Version: Version,
GitCommit: GitCommit,
BuildDate: BuildDate,
GoVersion: runtime.Version(),
}
}
// DashboardData holds summary data for the system dashboard.
type DashboardData struct {
Health *HealthStatus `json:"health"`
FactStats *memory.FactStoreStats `json:"fact_stats,omitempty"`
}
// Dashboard returns a summary of all system metrics.
func (s *SystemService) Dashboard(ctx context.Context) (*DashboardData, error) {
data := &DashboardData{
Health: s.Health(ctx),
}
if s.factStore != nil {
stats, err := s.factStore.Stats(ctx)
if err != nil {
return nil, fmt.Errorf("get fact stats: %w", err)
}
data.FactStats = stats
}
return data, nil
}

View file

@ -0,0 +1,94 @@
package tools
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestSystemService(t *testing.T) *SystemService {
t.Helper()
db, err := sqlite.OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
repo, err := sqlite.NewFactRepo(db)
require.NoError(t, err)
return NewSystemService(repo)
}
func TestSystemService_Health(t *testing.T) {
svc := newTestSystemService(t)
ctx := context.Background()
health := svc.Health(ctx)
require.NotNil(t, health)
assert.Equal(t, "healthy", health.Status)
assert.NotEmpty(t, health.GoVersion)
assert.NotEmpty(t, health.Version)
assert.NotEmpty(t, health.OS)
assert.NotEmpty(t, health.Arch)
assert.NotEmpty(t, health.Uptime)
}
func TestSystemService_GetVersion(t *testing.T) {
svc := newTestSystemService(t)
ver := svc.GetVersion()
require.NotNil(t, ver)
assert.NotEmpty(t, ver.Version)
assert.NotEmpty(t, ver.GoVersion)
assert.Equal(t, Version, ver.Version)
assert.Equal(t, GitCommit, ver.GitCommit)
assert.Equal(t, BuildDate, ver.BuildDate)
}
func TestSystemService_Dashboard(t *testing.T) {
svc := newTestSystemService(t)
ctx := context.Background()
data, err := svc.Dashboard(ctx)
require.NoError(t, err)
require.NotNil(t, data)
assert.NotNil(t, data.Health)
assert.Equal(t, "healthy", data.Health.Status)
assert.NotNil(t, data.FactStats)
assert.Equal(t, 0, data.FactStats.TotalFacts)
}
func TestSystemService_Dashboard_WithFacts(t *testing.T) {
svc := newTestSystemService(t)
ctx := context.Background()
// Add facts through the underlying store.
factSvc := NewFactService(svc.factStore, nil)
_, _ = factSvc.AddFact(ctx, AddFactParams{Content: "f1", Level: 0, Domain: "core"})
_, _ = factSvc.AddFact(ctx, AddFactParams{Content: "f2", Level: 1, Domain: "backend"})
data, err := svc.Dashboard(ctx)
require.NoError(t, err)
assert.Equal(t, 2, data.FactStats.TotalFacts)
}
func TestSystemService_Dashboard_NilFactStore(t *testing.T) {
svc := &SystemService{factStore: nil}
data, err := svc.Dashboard(context.Background())
require.NoError(t, err)
assert.NotNil(t, data.Health)
assert.Nil(t, data.FactStats)
}
func TestSystemService_Uptime(t *testing.T) {
svc := newTestSystemService(t)
ctx := context.Background()
h1 := svc.Health(ctx)
assert.NotEmpty(t, h1.Uptime)
// Uptime should be a parseable duration string like "0s" or "1ms".
assert.Contains(t, h1.Uptime, "s")
}