initial: Syntrex extraction from sentinel-community (615 files)

This commit is contained in:
DmitrL-dev 2026-03-11 15:12:02 +10:00
commit 2c50c993b1
175 changed files with 32396 additions and 0 deletions

View file

@ -0,0 +1,52 @@
package contextengine
import (
"encoding/json"
"os"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
)
// LoadConfig loads engine configuration from a JSON file.
// If the file does not exist, returns DefaultEngineConfig.
// If the file exists but is invalid, returns an error.
func LoadConfig(path string) (ctxdomain.EngineConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return ctxdomain.DefaultEngineConfig(), nil
}
return ctxdomain.EngineConfig{}, err
}
var cfg ctxdomain.EngineConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return ctxdomain.EngineConfig{}, err
}
// Build skip set from deserialized SkipTools slice.
cfg.BuildSkipSet()
// If skip_tools was omitted in JSON, use defaults.
if cfg.SkipTools == nil {
cfg.SkipTools = ctxdomain.DefaultSkipTools()
cfg.BuildSkipSet()
}
if err := cfg.Validate(); err != nil {
return ctxdomain.EngineConfig{}, err
}
return cfg, nil
}
// SaveDefaultConfig writes the default configuration to a JSON file.
// Useful for bootstrapping .rlm/context.json.
func SaveDefaultConfig(path string) error {
cfg := ctxdomain.DefaultEngineConfig()
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}

View file

@ -0,0 +1,110 @@
package contextengine
import (
"os"
"path/filepath"
"testing"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoadConfig_FileNotExists(t *testing.T) {
cfg, err := LoadConfig("/nonexistent/path/context.json")
require.NoError(t, err)
assert.Equal(t, ctxdomain.DefaultTokenBudget, cfg.TokenBudget)
assert.True(t, cfg.Enabled)
assert.NotEmpty(t, cfg.SkipTools)
}
func TestLoadConfig_ValidFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
content := `{
"token_budget": 500,
"max_facts": 15,
"recency_weight": 0.3,
"frequency_weight": 0.2,
"level_weight": 0.25,
"keyword_weight": 0.25,
"decay_half_life_hours": 48,
"enabled": true,
"skip_tools": ["health", "version"]
}`
err := os.WriteFile(path, []byte(content), 0o644)
require.NoError(t, err)
cfg, err := LoadConfig(path)
require.NoError(t, err)
assert.Equal(t, 500, cfg.TokenBudget)
assert.Equal(t, 15, cfg.MaxFacts)
assert.Equal(t, 0.3, cfg.RecencyWeight)
assert.Equal(t, 48.0, cfg.DecayHalfLifeHours)
assert.True(t, cfg.Enabled)
assert.Len(t, cfg.SkipTools, 2)
assert.True(t, cfg.ShouldSkip("health"))
assert.True(t, cfg.ShouldSkip("version"))
assert.False(t, cfg.ShouldSkip("search_facts"))
}
func TestLoadConfig_InvalidJSON(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
err := os.WriteFile(path, []byte("{invalid json"), 0o644)
require.NoError(t, err)
_, err = LoadConfig(path)
assert.Error(t, err)
}
func TestLoadConfig_InvalidConfig(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
content := `{"token_budget": 0, "max_facts": 5, "recency_weight": 0.5, "frequency_weight": 0.5, "level_weight": 0, "keyword_weight": 0, "decay_half_life_hours": 24, "enabled": true}`
err := os.WriteFile(path, []byte(content), 0o644)
require.NoError(t, err)
_, err = LoadConfig(path)
assert.Error(t, err, "should fail validation: token_budget=0")
}
func TestLoadConfig_OmittedSkipTools_UsesDefaults(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
content := `{
"token_budget": 300,
"max_facts": 10,
"recency_weight": 0.25,
"frequency_weight": 0.15,
"level_weight": 0.30,
"keyword_weight": 0.30,
"decay_half_life_hours": 72,
"enabled": true
}`
err := os.WriteFile(path, []byte(content), 0o644)
require.NoError(t, err)
cfg, err := LoadConfig(path)
require.NoError(t, err)
assert.NotEmpty(t, cfg.SkipTools, "omitted skip_tools should use defaults")
assert.True(t, cfg.ShouldSkip("search_facts"))
}
func TestSaveDefaultConfig(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "context.json")
err := SaveDefaultConfig(path)
require.NoError(t, err)
// Verify we can load what we saved
cfg, err := LoadConfig(path)
require.NoError(t, err)
assert.Equal(t, ctxdomain.DefaultTokenBudget, cfg.TokenBudget)
assert.True(t, cfg.Enabled)
}

View file

@ -0,0 +1,203 @@
// Package contextengine implements the Proactive Context Engine.
// It automatically injects relevant memory facts into every MCP tool response
// via ToolHandlerMiddleware, so the LLM always has context without asking.
package contextengine
import (
"context"
"log"
"strings"
"sync"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
)
// InteractionLogger records tool calls for crash-safe memory.
// Implementations must be safe for concurrent use.
type InteractionLogger interface {
Record(ctx context.Context, toolName string, args map[string]interface{}) error
}
// Engine is the Proactive Context Engine. It scores facts by relevance,
// selects the top ones within a token budget, and injects them into
// every tool response as a [MEMORY CONTEXT] block.
type Engine struct {
config ctxdomain.EngineConfig
scorer *ctxdomain.RelevanceScorer
provider ctxdomain.FactProvider
logger InteractionLogger // optional, nil = no logging
mu sync.RWMutex
accessCounts map[string]int // in-memory access counters per fact ID
}
// New creates a new Proactive Context Engine.
func New(cfg ctxdomain.EngineConfig, provider ctxdomain.FactProvider) *Engine {
return &Engine{
config: cfg,
scorer: ctxdomain.NewRelevanceScorer(cfg),
provider: provider,
accessCounts: make(map[string]int),
}
}
// SetInteractionLogger attaches an optional interaction logger for crash-safe
// tool call recording. If set, every tool call passing through the middleware
// will be recorded fire-and-forget (errors logged, never propagated).
func (e *Engine) SetInteractionLogger(l InteractionLogger) {
e.logger = l
}
// IsEnabled returns whether the engine is active.
func (e *Engine) IsEnabled() bool {
return e.config.Enabled
}
// BuildContext scores and selects relevant facts for the given tool call,
// returning a ContextFrame ready for formatting and injection.
func (e *Engine) BuildContext(toolName string, args map[string]interface{}) *ctxdomain.ContextFrame {
frame := ctxdomain.NewContextFrame(toolName, e.config.TokenBudget)
if !e.config.Enabled {
return frame
}
// Extract keywords from all string arguments
keywords := e.extractKeywordsFromArgs(args)
// Get candidate facts from provider
facts, err := e.provider.GetRelevantFacts(args)
if err != nil || len(facts) == 0 {
return frame
}
// Get current access counts snapshot
e.mu.RLock()
countsCopy := make(map[string]int, len(e.accessCounts))
for k, v := range e.accessCounts {
countsCopy[k] = v
}
e.mu.RUnlock()
// Score and rank facts
ranked := e.scorer.RankFacts(facts, keywords, countsCopy)
// Fill frame within token budget and max facts
added := 0
for _, sf := range ranked {
if added >= e.config.MaxFacts {
break
}
if frame.AddFact(sf) {
added++
// Record access for reinforcement
e.recordAccessInternal(sf.Fact.ID)
e.provider.RecordAccess(sf.Fact.ID)
}
}
return frame
}
// GetAccessCount returns the internal access count for a fact.
func (e *Engine) GetAccessCount(factID string) int {
e.mu.RLock()
defer e.mu.RUnlock()
return e.accessCounts[factID]
}
// recordAccessInternal increments the in-memory access counter.
func (e *Engine) recordAccessInternal(factID string) {
e.mu.Lock()
e.accessCounts[factID]++
e.mu.Unlock()
}
// Middleware returns a ToolHandlerMiddleware that wraps every tool handler
// to inject relevant memory context into the response and optionally
// record tool calls to the interaction log for crash-safe memory.
func (e *Engine) Middleware() server.ToolHandlerMiddleware {
return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Fire-and-forget: record this tool call in the interaction log.
// This runs BEFORE the handler so the record is persisted even if
// the process is killed mid-handler.
if e.logger != nil {
if logErr := e.logger.Record(ctx, req.Params.Name, req.GetArguments()); logErr != nil {
log.Printf("contextengine: interaction log error: %v", logErr)
}
}
// Call the original handler
result, err := next(ctx, req)
if err != nil {
return result, err
}
// Don't inject on nil result, error results, or empty content
if result == nil || result.IsError || len(result.Content) == 0 {
return result, nil
}
// Don't inject if engine is disabled
if !e.IsEnabled() {
return result, nil
}
// Don't inject for tools in the skip list
if e.config.ShouldSkip(req.Params.Name) {
return result, nil
}
// Build context frame
frame := e.BuildContext(req.Params.Name, req.GetArguments())
contextText := frame.Format()
if contextText == "" {
return result, nil
}
// Append context to the last text content block
e.appendContextToResult(result, contextText)
return result, nil
}
}
}
// appendContextToResult appends the context text to the last TextContent in the result.
func (e *Engine) appendContextToResult(result *mcp.CallToolResult, contextText string) {
for i := len(result.Content) - 1; i >= 0; i-- {
if tc, ok := result.Content[i].(mcp.TextContent); ok {
tc.Text += contextText
result.Content[i] = tc
return
}
}
// No text content found — add a new one
result.Content = append(result.Content, mcp.TextContent{
Type: "text",
Text: contextText,
})
}
// extractKeywordsFromArgs extracts keywords from all string values in the arguments map.
func (e *Engine) extractKeywordsFromArgs(args map[string]interface{}) []string {
if len(args) == 0 {
return nil
}
var allText strings.Builder
for _, v := range args {
switch val := v.(type) {
case string:
allText.WriteString(val)
allText.WriteString(" ")
}
}
return ctxdomain.ExtractKeywords(allText.String())
}

View file

@ -0,0 +1,708 @@
package contextengine
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/sentinel-community/gomcp/internal/domain/memory"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Mock FactProvider ---
type mockProvider struct {
mu sync.Mutex
facts []*memory.Fact
l0 []*memory.Fact
// tracks RecordAccess calls
accessed map[string]int
}
func newMockProvider(facts ...*memory.Fact) *mockProvider {
l0 := make([]*memory.Fact, 0)
for _, f := range facts {
if f.Level == memory.LevelProject {
l0 = append(l0, f)
}
}
return &mockProvider{
facts: facts,
l0: l0,
accessed: make(map[string]int),
}
}
func (m *mockProvider) GetRelevantFacts(_ map[string]interface{}) ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.facts, nil
}
func (m *mockProvider) GetL0Facts() ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.l0, nil
}
func (m *mockProvider) RecordAccess(factID string) {
m.mu.Lock()
defer m.mu.Unlock()
m.accessed[factID]++
}
// --- Engine tests ---
func TestNewEngine(t *testing.T) {
cfg := ctxdomain.DefaultEngineConfig()
provider := newMockProvider()
engine := New(cfg, provider)
require.NotNil(t, engine)
assert.True(t, engine.IsEnabled())
}
func TestNewEngine_Disabled(t *testing.T) {
cfg := ctxdomain.DefaultEngineConfig()
cfg.Enabled = false
engine := New(cfg, newMockProvider())
assert.False(t, engine.IsEnabled())
}
func TestEngine_BuildContext_NoFacts(t *testing.T) {
cfg := ctxdomain.DefaultEngineConfig()
provider := newMockProvider()
engine := New(cfg, provider)
frame := engine.BuildContext("test_tool", map[string]interface{}{
"content": "hello world",
})
assert.NotNil(t, frame)
assert.Empty(t, frame.Facts)
assert.Equal(t, "", frame.Format())
}
func TestEngine_BuildContext_WithFacts(t *testing.T) {
fact1 := memory.NewFact("Architecture uses clean layers", memory.LevelProject, "arch", "")
fact2 := memory.NewFact("TDD is mandatory for all code", memory.LevelProject, "process", "")
fact3 := memory.NewFact("Random snippet from old session", memory.LevelSnippet, "misc", "")
fact3.CreatedAt = time.Now().Add(-90 * 24 * time.Hour) // very old
provider := newMockProvider(fact1, fact2, fact3)
cfg := ctxdomain.DefaultEngineConfig()
cfg.TokenBudget = 500
engine := New(cfg, provider)
frame := engine.BuildContext("add_fact", map[string]interface{}{
"content": "architecture decision",
})
require.NotNil(t, frame)
assert.NotEmpty(t, frame.Facts)
// L0 facts should be included and ranked higher
assert.Equal(t, "add_fact", frame.ToolName)
formatted := frame.Format()
assert.Contains(t, formatted, "[MEMORY CONTEXT]")
assert.Contains(t, formatted, "[/MEMORY CONTEXT]")
}
func TestEngine_BuildContext_RespectsTokenBudget(t *testing.T) {
// Create many facts that exceed token budget
facts := make([]*memory.Fact, 50)
for i := 0; i < 50; i++ {
facts[i] = memory.NewFact(
fmt.Sprintf("Fact number %d with enough content to consume tokens in the budget allocation system", i),
memory.LevelProject, "arch", "",
)
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
cfg.TokenBudget = 100 // tight budget
cfg.MaxFacts = 50
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"query": "test"})
assert.LessOrEqual(t, frame.TokensUsed, cfg.TokenBudget)
}
func TestEngine_BuildContext_RespectsMaxFacts(t *testing.T) {
facts := make([]*memory.Fact, 20)
for i := 0; i < 20; i++ {
facts[i] = memory.NewFact(fmt.Sprintf("Fact %d", i), memory.LevelProject, "arch", "")
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
cfg.MaxFacts = 5
cfg.TokenBudget = 10000 // large budget so max_facts is the limiter
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"query": "fact"})
assert.LessOrEqual(t, len(frame.Facts), cfg.MaxFacts)
}
func TestEngine_BuildContext_DisabledReturnsEmpty(t *testing.T) {
fact := memory.NewFact("test", memory.LevelProject, "arch", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
cfg.Enabled = false
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"content": "test"})
assert.Empty(t, frame.Facts)
assert.Equal(t, "", frame.Format())
}
func TestEngine_RecordsAccess(t *testing.T) {
fact1 := memory.NewFact("Architecture pattern", memory.LevelProject, "arch", "")
provider := newMockProvider(fact1)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
frame := engine.BuildContext("test", map[string]interface{}{"content": "architecture"})
require.NotEmpty(t, frame.Facts)
// Check that RecordAccess was called on the provider
provider.mu.Lock()
count := provider.accessed[fact1.ID]
provider.mu.Unlock()
assert.Greater(t, count, 0, "RecordAccess should be called for injected facts")
}
func TestEngine_AccessCountTracking(t *testing.T) {
fact := memory.NewFact("Architecture decision", memory.LevelProject, "arch", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
// Build context 3 times
for i := 0; i < 3; i++ {
engine.BuildContext("test", map[string]interface{}{"content": "architecture"})
}
// Internal access count should be tracked
count := engine.GetAccessCount(fact.ID)
assert.Equal(t, 3, count)
}
func TestEngine_AccessCountInfluencesRanking(t *testing.T) {
// Two similar facts but one has been accessed more
fact1 := memory.NewFact("Architecture pattern A", memory.LevelDomain, "arch", "")
fact2 := memory.NewFact("Architecture pattern B", memory.LevelDomain, "arch", "")
provider := newMockProvider(fact1, fact2)
cfg := ctxdomain.DefaultEngineConfig()
cfg.FrequencyWeight = 0.9 // heavily weight frequency
cfg.KeywordWeight = 0.01
cfg.RecencyWeight = 0.01
cfg.LevelWeight = 0.01
engine := New(cfg, provider)
// Simulate fact1 being accessed many times
for i := 0; i < 20; i++ {
engine.recordAccessInternal(fact1.ID)
}
frame := engine.BuildContext("test", map[string]interface{}{"content": "architecture pattern"})
require.GreaterOrEqual(t, len(frame.Facts), 2)
// fact1 should rank higher due to frequency
assert.Equal(t, fact1.ID, frame.Facts[0].Fact.ID)
}
// --- Middleware tests ---
func TestMiddleware_InjectsContext(t *testing.T) {
fact := memory.NewFact("Always remember: TDD first", memory.LevelProject, "process", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
// Create a simple handler
handler := func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Original result"},
},
}, nil
}
// Wrap with middleware
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "test_tool"
req.Params.Arguments = map[string]interface{}{
"content": "TDD process",
}
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, result.Content, 1)
text := result.Content[0].(mcp.TextContent).Text
assert.Contains(t, text, "Original result")
assert.Contains(t, text, "[MEMORY CONTEXT]")
assert.Contains(t, text, "TDD first")
}
func TestMiddleware_DisabledPassesThrough(t *testing.T) {
fact := memory.NewFact("should not appear", memory.LevelProject, "test", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
cfg.Enabled = false
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Original only"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.Equal(t, "Original only", text)
assert.NotContains(t, text, "[MEMORY CONTEXT]")
}
func TestMiddleware_HandlerErrorPassedThrough(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("handler error")
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
assert.Error(t, err)
assert.Nil(t, result)
}
func TestMiddleware_ErrorResult_NoInjection(t *testing.T) {
fact := memory.NewFact("should not appear on errors", memory.LevelProject, "test", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Error: something failed"},
},
IsError: true,
}, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.NotContains(t, text, "[MEMORY CONTEXT]", "should not inject context on error results")
}
func TestMiddleware_EmptyContentSlice(t *testing.T) {
provider := newMockProvider(memory.NewFact("test", memory.LevelProject, "a", ""))
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{},
}, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
// Should handle empty content gracefully
assert.NotNil(t, result)
}
func TestMiddleware_NilResult(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return nil, nil
}
wrapped := engine.Middleware()(handler)
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
require.NoError(t, err)
assert.Nil(t, result)
}
func TestMiddleware_SkipListTools(t *testing.T) {
fact := memory.NewFact("Should not appear for skipped tools", memory.LevelProject, "test", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Facts result"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
// Tools in default skip list should NOT get context injected
skipTools := []string{"search_facts", "get_fact", "get_l0_facts", "health", "version", "dashboard"}
for _, tool := range skipTools {
t.Run(tool, func(t *testing.T) {
req := mcp.CallToolRequest{}
req.Params.Name = tool
req.Params.Arguments = map[string]interface{}{"query": "test"}
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.NotContains(t, text, "[MEMORY CONTEXT]",
"tool %s is in skip list, should not get context injected", tool)
})
}
}
func TestMiddleware_NonSkipToolGetsContext(t *testing.T) {
fact := memory.NewFact("Important architecture fact", memory.LevelProject, "arch", "")
provider := newMockProvider(fact)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "Tool result"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_causal_node"
req.Params.Arguments = map[string]interface{}{"content": "architecture decision"}
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
text := result.Content[0].(mcp.TextContent).Text
assert.Contains(t, text, "[MEMORY CONTEXT]")
}
// --- Concurrency test ---
func TestEngine_ConcurrentAccess(t *testing.T) {
facts := make([]*memory.Fact, 10)
for i := 0; i < 10; i++ {
facts[i] = memory.NewFact(fmt.Sprintf("Concurrent fact %d", i), memory.LevelProject, "arch", "")
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
engine.BuildContext("tool", map[string]interface{}{
"content": fmt.Sprintf("query %d", n),
})
}(i)
}
wg.Wait()
// Just verify no panics or races (run with -race)
for _, f := range facts {
count := engine.GetAccessCount(f.ID)
assert.GreaterOrEqual(t, count, 0)
}
}
// --- Benchmark ---
func BenchmarkEngine_BuildContext(b *testing.B) {
facts := make([]*memory.Fact, 100)
for i := 0; i < 100; i++ {
facts[i] = memory.NewFact(
"Architecture uses clean layers with dependency injection for modularity",
memory.HierLevel(i%4), "arch", "core",
)
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
args := map[string]interface{}{
"content": "architecture clean layers dependency",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
engine.BuildContext("test_tool", args)
}
}
func BenchmarkMiddleware(b *testing.B) {
facts := make([]*memory.Fact, 50)
for i := 0; i < 50; i++ {
facts[i] = memory.NewFact("test fact content", memory.LevelProject, "arch", "")
}
provider := newMockProvider(facts...)
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "result"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{"content": "test"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = wrapped(context.Background(), req)
}
}
// --- Mock InteractionLogger ---
type mockInteractionLogger struct {
mu sync.Mutex
entries []logEntry
failErr error // if set, Record returns this error
}
type logEntry struct {
toolName string
args map[string]interface{}
}
func (m *mockInteractionLogger) Record(_ context.Context, toolName string, args map[string]interface{}) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.failErr != nil {
return m.failErr
}
m.entries = append(m.entries, logEntry{toolName: toolName, args: args})
return nil
}
func (m *mockInteractionLogger) count() int {
m.mu.Lock()
defer m.mu.Unlock()
return len(m.entries)
}
func (m *mockInteractionLogger) lastToolName() string {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.entries) == 0 {
return ""
}
return m.entries[len(m.entries)-1].toolName
}
// --- Interaction Logger Tests ---
func TestMiddleware_InteractionLogger_RecordsToolCalls(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_fact"
req.Params.Arguments = map[string]interface{}{"content": "test fact"}
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, 1, logger.count())
assert.Equal(t, "add_fact", logger.lastToolName())
}
func TestMiddleware_InteractionLogger_RecordsSkippedTools(t *testing.T) {
// Even skip-list tools should be recorded in the interaction log
// (skip-list only controls context injection, not logging)
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
// "health" is in the skip list
req := mcp.CallToolRequest{}
req.Params.Name = "health"
_, err := wrapped(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, 1, logger.count(), "skip-list tools should still be logged")
assert.Equal(t, "health", logger.lastToolName())
}
func TestMiddleware_InteractionLogger_ErrorDoesNotBreakHandler(t *testing.T) {
// Logger errors must be swallowed — never break the tool call
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{failErr: fmt.Errorf("disk full")}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "handler succeeded"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_fact"
result, err := wrapped(context.Background(), req)
require.NoError(t, err, "logger error must not propagate")
require.NotNil(t, result)
text := result.Content[0].(mcp.TextContent).Text
assert.Contains(t, text, "handler succeeded")
}
func TestMiddleware_NoLogger_StillWorks(t *testing.T) {
// Without a logger set, middleware should work normally
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
// engine.logger is nil — no SetInteractionLogger call
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "no logger ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
req := mcp.CallToolRequest{}
req.Params.Name = "add_fact"
result, err := wrapped(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, result)
}
func TestMiddleware_InteractionLogger_MultipleToolCalls(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
toolNames := []string{"add_fact", "search_facts", "health", "add_causal_node", "version"}
for _, name := range toolNames {
req := mcp.CallToolRequest{}
req.Params.Name = name
_, _ = wrapped(context.Background(), req)
}
assert.Equal(t, 5, logger.count(), "all 5 tool calls should be logged")
}
func TestMiddleware_InteractionLogger_ConcurrentCalls(t *testing.T) {
provider := newMockProvider()
cfg := ctxdomain.DefaultEngineConfig()
engine := New(cfg, provider)
logger := &mockInteractionLogger{}
engine.SetInteractionLogger(logger)
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "ok"},
},
}, nil
}
wrapped := engine.Middleware()(handler)
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
req := mcp.CallToolRequest{}
req.Params.Name = fmt.Sprintf("tool_%d", n)
_, _ = wrapped(context.Background(), req)
}(i)
}
wg.Wait()
assert.Equal(t, 20, logger.count(), "all 20 concurrent calls should be logged")
}

View file

@ -0,0 +1,245 @@
// Package contextengine — processor.go
// Processes unprocessed interaction log entries into session summary facts.
// This closes the memory loop: tool calls → interaction log → summary facts → boot instructions.
package contextengine
import (
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
)
// InteractionProcessor processes unprocessed interaction log entries
// and creates session summary facts from them.
type InteractionProcessor struct {
repo *sqlite.InteractionLogRepo
factStore memory.FactStore
}
// NewInteractionProcessor creates a new processor.
func NewInteractionProcessor(repo *sqlite.InteractionLogRepo, store memory.FactStore) *InteractionProcessor {
return &InteractionProcessor{repo: repo, factStore: store}
}
// ProcessStartup processes unprocessed entries from a previous (possibly crashed) session.
// It creates an L1 "session summary" fact and marks all entries as processed.
// Returns the summary text (empty if nothing to process).
func (p *InteractionProcessor) ProcessStartup(ctx context.Context) (string, error) {
entries, err := p.repo.GetUnprocessed(ctx)
if err != nil {
return "", fmt.Errorf("get unprocessed: %w", err)
}
if len(entries) == 0 {
return "", nil
}
summary := buildSessionSummary(entries, "previous session (recovered)")
if summary == "" {
return "", nil
}
// Save as L1 fact (domain-level, not project-level)
fact := memory.NewFact(summary, memory.LevelDomain, "session-history", "interaction-processor")
fact.Source = "auto:interaction-processor"
if err := p.factStore.Add(ctx, fact); err != nil {
return "", fmt.Errorf("save session summary fact: %w", err)
}
// Mark all as processed
ids := make([]int64, len(entries))
for i, e := range entries {
ids[i] = e.ID
}
if err := p.repo.MarkProcessed(ctx, ids); err != nil {
return "", fmt.Errorf("mark processed: %w", err)
}
return summary, nil
}
// ProcessShutdown processes entries from the current session at graceful shutdown.
// Similar to ProcessStartup but labels differently.
func (p *InteractionProcessor) ProcessShutdown(ctx context.Context) (string, error) {
entries, err := p.repo.GetUnprocessed(ctx)
if err != nil {
return "", fmt.Errorf("get unprocessed: %w", err)
}
if len(entries) == 0 {
return "", nil
}
summary := buildSessionSummary(entries, "session ending "+time.Now().Format("2006-01-02 15:04"))
if summary == "" {
return "", nil
}
fact := memory.NewFact(summary, memory.LevelDomain, "session-history", "interaction-processor")
fact.Source = "auto:session-shutdown"
if err := p.factStore.Add(ctx, fact); err != nil {
return "", fmt.Errorf("save session summary fact: %w", err)
}
ids := make([]int64, len(entries))
for i, e := range entries {
ids[i] = e.ID
}
if err := p.repo.MarkProcessed(ctx, ids); err != nil {
return "", fmt.Errorf("mark processed: %w", err)
}
return summary, nil
}
// buildSessionSummary creates a compact text summary from interaction log entries.
func buildSessionSummary(entries []sqlite.InteractionEntry, label string) string {
if len(entries) == 0 {
return ""
}
// Count tool calls
toolCounts := make(map[string]int)
for _, e := range entries {
toolCounts[e.ToolName]++
}
// Sort by count descending
type toolStat struct {
name string
count int
}
stats := make([]toolStat, 0, len(toolCounts))
for name, count := range toolCounts {
stats = append(stats, toolStat{name, count})
}
sort.Slice(stats, func(i, j int) bool { return stats[i].count > stats[j].count })
// Extract topics from args (unique string values)
topics := extractTopicsFromEntries(entries)
// Time range
var earliest, latest time.Time
for _, e := range entries {
if earliest.IsZero() || e.Timestamp.Before(earliest) {
earliest = e.Timestamp
}
if latest.IsZero() || e.Timestamp.After(latest) {
latest = e.Timestamp
}
}
var b strings.Builder
b.WriteString(fmt.Sprintf("Session summary (%s): %d tool calls", label, len(entries)))
if !earliest.IsZero() {
duration := latest.Sub(earliest)
if duration > 0 {
b.WriteString(fmt.Sprintf(" over %s", formatDuration(duration)))
}
}
b.WriteString(". ")
// Top tools used
b.WriteString("Tools used: ")
for i, ts := range stats {
if i >= 8 {
b.WriteString(fmt.Sprintf(" +%d more", len(stats)-8))
break
}
if i > 0 {
b.WriteString(", ")
}
b.WriteString(fmt.Sprintf("%s(%d)", ts.name, ts.count))
}
b.WriteString(". ")
// Topics
if len(topics) > 0 {
b.WriteString("Topics: ")
limit := 10
if len(topics) < limit {
limit = len(topics)
}
b.WriteString(strings.Join(topics[:limit], ", "))
if len(topics) > limit {
b.WriteString(fmt.Sprintf(" +%d more", len(topics)-limit))
}
b.WriteString(".")
}
return b.String()
}
// extractTopicsFromEntries pulls unique meaningful strings from tool arguments.
func extractTopicsFromEntries(entries []sqlite.InteractionEntry) []string {
seen := make(map[string]bool)
var topics []string
for _, e := range entries {
if e.ArgsJSON == "" {
continue
}
// Simple extraction: find quoted strings in JSON args
// ArgsJSON looks like {"query":"architecture","content":"some fact"}
parts := strings.Split(e.ArgsJSON, "\"")
for i := 3; i < len(parts); i += 4 {
// Values are at odd positions after the key
val := parts[i]
if len(val) < 3 || len(val) > 100 {
continue
}
// Skip common non-topic values
lower := strings.ToLower(val)
if lower == "true" || lower == "false" || lower == "null" || lower == "" {
continue
}
if !seen[lower] {
seen[lower] = true
topics = append(topics, val)
}
}
}
return topics
}
// formatDuration formats a duration into a human-readable string.
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%ds", int(d.Seconds()))
}
if d < time.Hour {
return fmt.Sprintf("%dm", int(d.Minutes()))
}
return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60)
}
// GetLastSessionSummary searches the fact store for the most recent session summary.
func GetLastSessionSummary(ctx context.Context, store memory.FactStore) string {
facts, err := store.Search(ctx, "Session summary", 5)
if err != nil || len(facts) == 0 {
return ""
}
// Find the most recent one from session-history domain
var best *memory.Fact
for _, f := range facts {
if f.Domain != "session-history" {
continue
}
if f.IsStale || f.IsArchived {
continue
}
if best == nil || f.CreatedAt.After(best.CreatedAt) {
best = f
}
}
if best == nil {
return ""
}
return best.Content
}

View file

@ -0,0 +1,251 @@
package contextengine
import (
"context"
"testing"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- mock FactStore for processor tests ---
type procMockFactStore struct {
facts []*memory.Fact
}
func (m *procMockFactStore) Add(_ context.Context, f *memory.Fact) error {
m.facts = append(m.facts, f)
return nil
}
func (m *procMockFactStore) Get(_ context.Context, id string) (*memory.Fact, error) {
for _, f := range m.facts {
if f.ID == id {
return f, nil
}
}
return nil, nil
}
func (m *procMockFactStore) Update(_ context.Context, _ *memory.Fact) error { return nil }
func (m *procMockFactStore) Delete(_ context.Context, _ string) error { return nil }
func (m *procMockFactStore) ListByDomain(_ context.Context, _ string, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) ListByLevel(_ context.Context, _ memory.HierLevel) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) ListDomains(_ context.Context) ([]string, error) { return nil, nil }
func (m *procMockFactStore) GetStale(_ context.Context, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) Search(_ context.Context, query string, limit int) ([]*memory.Fact, error) {
var results []*memory.Fact
for _, f := range m.facts {
if len(results) >= limit {
break
}
if contains(f.Content, query) {
results = append(results, f)
}
}
return results, nil
}
func (m *procMockFactStore) GetExpired(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
func (m *procMockFactStore) RefreshTTL(_ context.Context, _ string) error { return nil }
func (m *procMockFactStore) TouchFact(_ context.Context, _ string) error { return nil }
func (m *procMockFactStore) GetColdFacts(_ context.Context, _ int) ([]*memory.Fact, error) {
return nil, nil
}
func (m *procMockFactStore) CompressFacts(_ context.Context, _ []string, _ string) (string, error) {
return "", nil
}
func (m *procMockFactStore) Stats(_ context.Context) (*memory.FactStoreStats, error) {
return &memory.FactStoreStats{}, nil
}
func (m *procMockFactStore) ListGenes(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
func contains(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(sub) == 0 ||
(len(s) > 0 && len(sub) > 0 && containsStr(s, sub)))
}
func containsStr(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
// --- Tests ---
func TestInteractionProcessor_ProcessStartup_NoEntries(t *testing.T) {
db, err := sqlite.OpenMemory()
require.NoError(t, err)
defer db.Close()
repo, err := sqlite.NewInteractionLogRepo(db)
require.NoError(t, err)
store := &procMockFactStore{}
proc := NewInteractionProcessor(repo, store)
summary, err := proc.ProcessStartup(context.Background())
require.NoError(t, err)
assert.Empty(t, summary)
assert.Empty(t, store.facts, "no facts should be created")
}
func TestInteractionProcessor_ProcessStartup_CreatesFactAndMarksProcessed(t *testing.T) {
db, err := sqlite.OpenMemory()
require.NoError(t, err)
defer db.Close()
repo, err := sqlite.NewInteractionLogRepo(db)
require.NoError(t, err)
ctx := context.Background()
// Insert some tool calls
require.NoError(t, repo.Record(ctx, "add_fact", map[string]interface{}{"content": "test fact about architecture"}))
require.NoError(t, repo.Record(ctx, "search_facts", map[string]interface{}{"query": "security"}))
require.NoError(t, repo.Record(ctx, "health", nil))
require.NoError(t, repo.Record(ctx, "add_fact", map[string]interface{}{"content": "another fact"}))
require.NoError(t, repo.Record(ctx, "dashboard", nil))
store := &procMockFactStore{}
proc := NewInteractionProcessor(repo, store)
summary, err := proc.ProcessStartup(ctx)
require.NoError(t, err)
assert.NotEmpty(t, summary)
assert.Contains(t, summary, "Session summary")
assert.Contains(t, summary, "5 tool calls")
assert.Contains(t, summary, "add_fact(2)")
// Fact should be saved
require.Len(t, store.facts, 1)
assert.Equal(t, memory.LevelDomain, store.facts[0].Level)
assert.Equal(t, "session-history", store.facts[0].Domain)
assert.Equal(t, "auto:interaction-processor", store.facts[0].Source)
// All entries should be marked processed
_, unprocessed, err := repo.Count(ctx)
require.NoError(t, err)
assert.Equal(t, 0, unprocessed)
}
func TestInteractionProcessor_ProcessShutdown(t *testing.T) {
db, err := sqlite.OpenMemory()
require.NoError(t, err)
defer db.Close()
repo, err := sqlite.NewInteractionLogRepo(db)
require.NoError(t, err)
ctx := context.Background()
require.NoError(t, repo.Record(ctx, "version", nil))
require.NoError(t, repo.Record(ctx, "search_facts", map[string]interface{}{"query": "gomcp"}))
store := &procMockFactStore{}
proc := NewInteractionProcessor(repo, store)
summary, err := proc.ProcessShutdown(ctx)
require.NoError(t, err)
assert.Contains(t, summary, "session ending")
assert.Contains(t, summary, "2 tool calls")
require.Len(t, store.facts, 1)
assert.Equal(t, "auto:session-shutdown", store.facts[0].Source)
}
func TestBuildSessionSummary_ToolCounts(t *testing.T) {
now := time.Now()
entries := []sqlite.InteractionEntry{
{ID: 1, ToolName: "add_fact", Timestamp: now},
{ID: 2, ToolName: "add_fact", Timestamp: now},
{ID: 3, ToolName: "add_fact", Timestamp: now},
{ID: 4, ToolName: "search_facts", Timestamp: now},
{ID: 5, ToolName: "health", Timestamp: now},
}
summary := buildSessionSummary(entries, "test")
assert.Contains(t, summary, "5 tool calls")
assert.Contains(t, summary, "add_fact(3)")
assert.Contains(t, summary, "search_facts(1)")
assert.Contains(t, summary, "health(1)")
}
func TestBuildSessionSummary_Duration(t *testing.T) {
now := time.Now()
entries := []sqlite.InteractionEntry{
{ID: 1, ToolName: "a", Timestamp: now.Add(-30 * time.Minute)},
{ID: 2, ToolName: "b", Timestamp: now},
}
summary := buildSessionSummary(entries, "test")
assert.Contains(t, summary, "30m")
}
func TestBuildSessionSummary_Empty(t *testing.T) {
summary := buildSessionSummary(nil, "test")
assert.Empty(t, summary)
}
func TestBuildSessionSummary_Topics(t *testing.T) {
now := time.Now()
entries := []sqlite.InteractionEntry{
{ID: 1, ToolName: "search_facts", ArgsJSON: `{"query":"architecture"}`, Timestamp: now},
{ID: 2, ToolName: "add_fact", ArgsJSON: `{"content":"security review"}`, Timestamp: now},
}
summary := buildSessionSummary(entries, "test")
assert.Contains(t, summary, "Topics:")
assert.Contains(t, summary, "architecture")
}
func TestGetLastSessionSummary_Found(t *testing.T) {
store := &procMockFactStore{}
f := memory.NewFact("Session summary (test): 5 tool calls", memory.LevelDomain, "session-history", "")
f.Source = "auto:session-shutdown"
store.facts = append(store.facts, f)
result := GetLastSessionSummary(context.Background(), store)
assert.Contains(t, result, "Session summary")
}
func TestGetLastSessionSummary_NotFound(t *testing.T) {
store := &procMockFactStore{}
result := GetLastSessionSummary(context.Background(), store)
assert.Empty(t, result)
}
func TestGetLastSessionSummary_SkipsStale(t *testing.T) {
store := &procMockFactStore{}
f := memory.NewFact("Session summary (test): old", memory.LevelDomain, "session-history", "")
f.IsStale = true
store.facts = append(store.facts, f)
result := GetLastSessionSummary(context.Background(), store)
assert.Empty(t, result)
}
func TestFormatDuration(t *testing.T) {
assert.Equal(t, "30s", formatDuration(30*time.Second))
assert.Equal(t, "5m", formatDuration(5*time.Minute))
assert.Equal(t, "2h15m", formatDuration(2*time.Hour+15*time.Minute))
}
func TestExtractTopicsFromEntries(t *testing.T) {
entries := []sqlite.InteractionEntry{
{ArgsJSON: `{"query":"architecture"}`},
{ArgsJSON: `{"content":"security review"}`},
{ArgsJSON: ``}, // empty
}
topics := extractTopicsFromEntries(entries)
assert.NotEmpty(t, topics)
}

View file

@ -0,0 +1,117 @@
package contextengine
import (
"context"
"sync"
"github.com/sentinel-community/gomcp/internal/domain/memory"
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
)
// StoreFactProvider adapts FactStore + HotCache to the FactProvider interface,
// bridging infrastructure storage with the context engine domain.
type StoreFactProvider struct {
store memory.FactStore
cache memory.HotCache
mu sync.Mutex
accessCounts map[string]int
}
// NewStoreFactProvider creates a FactProvider backed by FactStore and optional HotCache.
func NewStoreFactProvider(store memory.FactStore, cache memory.HotCache) *StoreFactProvider {
return &StoreFactProvider{
store: store,
cache: cache,
accessCounts: make(map[string]int),
}
}
// Verify interface compliance at compile time.
var _ ctxdomain.FactProvider = (*StoreFactProvider)(nil)
// GetRelevantFacts returns candidate facts for context injection.
// Uses keyword search from tool arguments + L0 facts as candidates.
func (p *StoreFactProvider) GetRelevantFacts(args map[string]interface{}) ([]*memory.Fact, error) {
ctx := context.Background()
// Always include L0 facts
l0Facts, err := p.GetL0Facts()
if err != nil {
return nil, err
}
// Extract query text from arguments for search
query := extractQueryFromArgs(args)
if query == "" {
return l0Facts, nil
}
// Search for additional relevant facts
searchResults, err := p.store.Search(ctx, query, 30)
if err != nil {
// Degrade gracefully — just return L0 facts
return l0Facts, nil
}
// Merge L0 + search results, deduplicating by ID
seen := make(map[string]bool, len(l0Facts))
merged := make([]*memory.Fact, 0, len(l0Facts)+len(searchResults))
for _, f := range l0Facts {
seen[f.ID] = true
merged = append(merged, f)
}
for _, f := range searchResults {
if !seen[f.ID] {
seen[f.ID] = true
merged = append(merged, f)
}
}
return merged, nil
}
// GetL0Facts returns all L0 (project-level) facts.
// Uses HotCache if available, falls back to store.
func (p *StoreFactProvider) GetL0Facts() ([]*memory.Fact, error) {
ctx := context.Background()
if p.cache != nil {
facts, err := p.cache.GetL0Facts(ctx)
if err == nil && len(facts) > 0 {
return facts, nil
}
}
return p.store.ListByLevel(ctx, memory.LevelProject)
}
// RecordAccess increments the access counter for a fact.
func (p *StoreFactProvider) RecordAccess(factID string) {
p.mu.Lock()
p.accessCounts[factID]++
p.mu.Unlock()
}
// extractQueryFromArgs builds a search query string from argument values.
func extractQueryFromArgs(args map[string]interface{}) string {
var parts []string
for _, v := range args {
if s, ok := v.(string); ok && s != "" {
parts = append(parts, s)
}
}
if len(parts) == 0 {
return ""
}
result := ""
for i, p := range parts {
if i > 0 {
result += " "
}
result += p
}
return result
}

View file

@ -0,0 +1,278 @@
package contextengine
import (
"context"
"fmt"
"sync"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Mock FactStore for provider tests ---
type mockFactStore struct {
mu sync.Mutex
facts map[string]*memory.Fact
searchFacts []*memory.Fact
searchErr error
levelFacts map[memory.HierLevel][]*memory.Fact
}
func newMockFactStore() *mockFactStore {
return &mockFactStore{
facts: make(map[string]*memory.Fact),
levelFacts: make(map[memory.HierLevel][]*memory.Fact),
}
}
func (m *mockFactStore) Add(_ context.Context, fact *memory.Fact) error {
m.mu.Lock()
defer m.mu.Unlock()
m.facts[fact.ID] = fact
m.levelFacts[fact.Level] = append(m.levelFacts[fact.Level], fact)
return nil
}
func (m *mockFactStore) Get(_ context.Context, id string) (*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
f, ok := m.facts[id]
if !ok {
return nil, fmt.Errorf("not found")
}
return f, nil
}
func (m *mockFactStore) Update(_ context.Context, fact *memory.Fact) error {
m.mu.Lock()
defer m.mu.Unlock()
m.facts[fact.ID] = fact
return nil
}
func (m *mockFactStore) Delete(_ context.Context, id string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.facts, id)
return nil
}
func (m *mockFactStore) ListByDomain(_ context.Context, _ string, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) ListByLevel(_ context.Context, level memory.HierLevel) ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.levelFacts[level], nil
}
func (m *mockFactStore) ListDomains(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *mockFactStore) GetStale(_ context.Context, _ bool) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) Search(_ context.Context, _ string, _ int) ([]*memory.Fact, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.searchFacts, m.searchErr
}
func (m *mockFactStore) GetExpired(_ context.Context) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) RefreshTTL(_ context.Context, _ string) error {
return nil
}
func (m *mockFactStore) TouchFact(_ context.Context, _ string) error { return nil }
func (m *mockFactStore) GetColdFacts(_ context.Context, _ int) ([]*memory.Fact, error) {
return nil, nil
}
func (m *mockFactStore) CompressFacts(_ context.Context, _ []string, _ string) (string, error) {
return "", nil
}
func (m *mockFactStore) Stats(_ context.Context) (*memory.FactStoreStats, error) {
return nil, nil
}
func (m *mockFactStore) ListGenes(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
// --- Mock HotCache ---
type mockHotCache struct {
l0Facts []*memory.Fact
l0Err error
}
func (m *mockHotCache) GetL0Facts(_ context.Context) ([]*memory.Fact, error) {
return m.l0Facts, m.l0Err
}
func (m *mockHotCache) InvalidateFact(_ context.Context, _ string) error { return nil }
func (m *mockHotCache) WarmUp(_ context.Context, _ []*memory.Fact) error { return nil }
func (m *mockHotCache) Close() error { return nil }
// --- StoreFactProvider tests ---
func TestNewStoreFactProvider(t *testing.T) {
store := newMockFactStore()
provider := NewStoreFactProvider(store, nil)
require.NotNil(t, provider)
}
func TestStoreFactProvider_GetL0Facts_FromStore(t *testing.T) {
store := newMockFactStore()
f1 := memory.NewFact("L0 fact A", memory.LevelProject, "arch", "")
f2 := memory.NewFact("L0 fact B", memory.LevelProject, "process", "")
_ = store.Add(context.Background(), f1)
_ = store.Add(context.Background(), f2)
provider := NewStoreFactProvider(store, nil) // no cache
facts, err := provider.GetL0Facts()
require.NoError(t, err)
assert.Len(t, facts, 2)
}
func TestStoreFactProvider_GetL0Facts_FromCache(t *testing.T) {
store := newMockFactStore()
cacheFact := memory.NewFact("Cached L0", memory.LevelProject, "arch", "")
cache := &mockHotCache{l0Facts: []*memory.Fact{cacheFact}}
provider := NewStoreFactProvider(store, cache)
facts, err := provider.GetL0Facts()
require.NoError(t, err)
assert.Len(t, facts, 1)
assert.Equal(t, "Cached L0", facts[0].Content)
}
func TestStoreFactProvider_GetL0Facts_CacheFallbackToStore(t *testing.T) {
store := newMockFactStore()
storeFact := memory.NewFact("Store L0", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), storeFact)
cache := &mockHotCache{l0Facts: nil, l0Err: fmt.Errorf("cache miss")}
provider := NewStoreFactProvider(store, cache)
facts, err := provider.GetL0Facts()
require.NoError(t, err)
assert.Len(t, facts, 1)
assert.Equal(t, "Store L0", facts[0].Content)
}
func TestStoreFactProvider_GetRelevantFacts_NoQuery(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 always included", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), l0)
provider := NewStoreFactProvider(store, nil)
// No string args → no query → only L0
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"level": 0,
})
require.NoError(t, err)
assert.Len(t, facts, 1)
}
func TestStoreFactProvider_GetRelevantFacts_WithSearch(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 architecture", memory.LevelProject, "arch", "")
searchResult := memory.NewFact("Found by search", memory.LevelDomain, "auth", "")
_ = store.Add(context.Background(), l0)
store.searchFacts = []*memory.Fact{searchResult}
provider := NewStoreFactProvider(store, nil)
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"content": "authentication module",
})
require.NoError(t, err)
assert.Len(t, facts, 2) // L0 + search result
}
func TestStoreFactProvider_GetRelevantFacts_Deduplication(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 architecture", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), l0)
// Search returns the same fact that's also L0
store.searchFacts = []*memory.Fact{l0}
provider := NewStoreFactProvider(store, nil)
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"content": "architecture",
})
require.NoError(t, err)
assert.Len(t, facts, 1, "duplicate should be removed")
}
func TestStoreFactProvider_GetRelevantFacts_SearchError_GracefulDegradation(t *testing.T) {
store := newMockFactStore()
l0 := memory.NewFact("L0 fact", memory.LevelProject, "arch", "")
_ = store.Add(context.Background(), l0)
store.searchErr = fmt.Errorf("search broken")
provider := NewStoreFactProvider(store, nil)
facts, err := provider.GetRelevantFacts(map[string]interface{}{
"content": "test query",
})
require.NoError(t, err, "should degrade gracefully, not error")
assert.Len(t, facts, 1, "should still return L0 facts")
}
func TestStoreFactProvider_RecordAccess(t *testing.T) {
store := newMockFactStore()
provider := NewStoreFactProvider(store, nil)
provider.RecordAccess("fact-1")
provider.RecordAccess("fact-1")
provider.RecordAccess("fact-2")
provider.mu.Lock()
assert.Equal(t, 2, provider.accessCounts["fact-1"])
assert.Equal(t, 1, provider.accessCounts["fact-2"])
provider.mu.Unlock()
}
func TestExtractQueryFromArgs(t *testing.T) {
tests := []struct {
name string
args map[string]interface{}
want string
}{
{"nil", nil, ""},
{"empty", map[string]interface{}{}, ""},
{"no strings", map[string]interface{}{"level": 0, "flag": true}, ""},
{"single string", map[string]interface{}{"content": "hello"}, "hello"},
{"empty string", map[string]interface{}{"content": ""}, ""},
{"mixed", map[string]interface{}{"content": "hello", "level": 0, "domain": "arch"}, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractQueryFromArgs(tt.args)
if tt.name == "mixed" {
// Map iteration order is non-deterministic, just check non-empty
assert.NotEmpty(t, got)
} else {
if tt.want == "" {
assert.Empty(t, got)
} else {
assert.Contains(t, got, tt.want)
}
}
})
}
}