mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-05-04 16:52:36 +02:00
initial: Syntrex extraction from sentinel-community (615 files)
This commit is contained in:
commit
2c50c993b1
175 changed files with 32396 additions and 0 deletions
52
internal/application/contextengine/config.go
Normal file
52
internal/application/contextengine/config.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package contextengine
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
|
||||
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
|
||||
)
|
||||
|
||||
// LoadConfig loads engine configuration from a JSON file.
|
||||
// If the file does not exist, returns DefaultEngineConfig.
|
||||
// If the file exists but is invalid, returns an error.
|
||||
func LoadConfig(path string) (ctxdomain.EngineConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return ctxdomain.DefaultEngineConfig(), nil
|
||||
}
|
||||
return ctxdomain.EngineConfig{}, err
|
||||
}
|
||||
|
||||
var cfg ctxdomain.EngineConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ctxdomain.EngineConfig{}, err
|
||||
}
|
||||
|
||||
// Build skip set from deserialized SkipTools slice.
|
||||
cfg.BuildSkipSet()
|
||||
|
||||
// If skip_tools was omitted in JSON, use defaults.
|
||||
if cfg.SkipTools == nil {
|
||||
cfg.SkipTools = ctxdomain.DefaultSkipTools()
|
||||
cfg.BuildSkipSet()
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return ctxdomain.EngineConfig{}, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// SaveDefaultConfig writes the default configuration to a JSON file.
|
||||
// Useful for bootstrapping .rlm/context.json.
|
||||
func SaveDefaultConfig(path string) error {
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
110
internal/application/contextengine/config_test.go
Normal file
110
internal/application/contextengine/config_test.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
package contextengine
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoadConfig_FileNotExists(t *testing.T) {
|
||||
cfg, err := LoadConfig("/nonexistent/path/context.json")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ctxdomain.DefaultTokenBudget, cfg.TokenBudget)
|
||||
assert.True(t, cfg.Enabled)
|
||||
assert.NotEmpty(t, cfg.SkipTools)
|
||||
}
|
||||
|
||||
func TestLoadConfig_ValidFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "context.json")
|
||||
|
||||
content := `{
|
||||
"token_budget": 500,
|
||||
"max_facts": 15,
|
||||
"recency_weight": 0.3,
|
||||
"frequency_weight": 0.2,
|
||||
"level_weight": 0.25,
|
||||
"keyword_weight": 0.25,
|
||||
"decay_half_life_hours": 48,
|
||||
"enabled": true,
|
||||
"skip_tools": ["health", "version"]
|
||||
}`
|
||||
err := os.WriteFile(path, []byte(content), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := LoadConfig(path)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 500, cfg.TokenBudget)
|
||||
assert.Equal(t, 15, cfg.MaxFacts)
|
||||
assert.Equal(t, 0.3, cfg.RecencyWeight)
|
||||
assert.Equal(t, 48.0, cfg.DecayHalfLifeHours)
|
||||
assert.True(t, cfg.Enabled)
|
||||
assert.Len(t, cfg.SkipTools, 2)
|
||||
assert.True(t, cfg.ShouldSkip("health"))
|
||||
assert.True(t, cfg.ShouldSkip("version"))
|
||||
assert.False(t, cfg.ShouldSkip("search_facts"))
|
||||
}
|
||||
|
||||
func TestLoadConfig_InvalidJSON(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "context.json")
|
||||
|
||||
err := os.WriteFile(path, []byte("{invalid json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = LoadConfig(path)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLoadConfig_InvalidConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "context.json")
|
||||
|
||||
content := `{"token_budget": 0, "max_facts": 5, "recency_weight": 0.5, "frequency_weight": 0.5, "level_weight": 0, "keyword_weight": 0, "decay_half_life_hours": 24, "enabled": true}`
|
||||
err := os.WriteFile(path, []byte(content), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = LoadConfig(path)
|
||||
assert.Error(t, err, "should fail validation: token_budget=0")
|
||||
}
|
||||
|
||||
func TestLoadConfig_OmittedSkipTools_UsesDefaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "context.json")
|
||||
|
||||
content := `{
|
||||
"token_budget": 300,
|
||||
"max_facts": 10,
|
||||
"recency_weight": 0.25,
|
||||
"frequency_weight": 0.15,
|
||||
"level_weight": 0.30,
|
||||
"keyword_weight": 0.30,
|
||||
"decay_half_life_hours": 72,
|
||||
"enabled": true
|
||||
}`
|
||||
err := os.WriteFile(path, []byte(content), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := LoadConfig(path)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, cfg.SkipTools, "omitted skip_tools should use defaults")
|
||||
assert.True(t, cfg.ShouldSkip("search_facts"))
|
||||
}
|
||||
|
||||
func TestSaveDefaultConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "context.json")
|
||||
|
||||
err := SaveDefaultConfig(path)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we can load what we saved
|
||||
cfg, err := LoadConfig(path)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ctxdomain.DefaultTokenBudget, cfg.TokenBudget)
|
||||
assert.True(t, cfg.Enabled)
|
||||
}
|
||||
203
internal/application/contextengine/engine.go
Normal file
203
internal/application/contextengine/engine.go
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
// Package contextengine implements the Proactive Context Engine.
|
||||
// It automatically injects relevant memory facts into every MCP tool response
|
||||
// via ToolHandlerMiddleware, so the LLM always has context without asking.
|
||||
package contextengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
|
||||
)
|
||||
|
||||
// InteractionLogger records tool calls for crash-safe memory.
|
||||
// Implementations must be safe for concurrent use.
|
||||
type InteractionLogger interface {
|
||||
Record(ctx context.Context, toolName string, args map[string]interface{}) error
|
||||
}
|
||||
|
||||
// Engine is the Proactive Context Engine. It scores facts by relevance,
|
||||
// selects the top ones within a token budget, and injects them into
|
||||
// every tool response as a [MEMORY CONTEXT] block.
|
||||
type Engine struct {
|
||||
config ctxdomain.EngineConfig
|
||||
scorer *ctxdomain.RelevanceScorer
|
||||
provider ctxdomain.FactProvider
|
||||
logger InteractionLogger // optional, nil = no logging
|
||||
|
||||
mu sync.RWMutex
|
||||
accessCounts map[string]int // in-memory access counters per fact ID
|
||||
}
|
||||
|
||||
// New creates a new Proactive Context Engine.
|
||||
func New(cfg ctxdomain.EngineConfig, provider ctxdomain.FactProvider) *Engine {
|
||||
return &Engine{
|
||||
config: cfg,
|
||||
scorer: ctxdomain.NewRelevanceScorer(cfg),
|
||||
provider: provider,
|
||||
accessCounts: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// SetInteractionLogger attaches an optional interaction logger for crash-safe
|
||||
// tool call recording. If set, every tool call passing through the middleware
|
||||
// will be recorded fire-and-forget (errors logged, never propagated).
|
||||
func (e *Engine) SetInteractionLogger(l InteractionLogger) {
|
||||
e.logger = l
|
||||
}
|
||||
|
||||
// IsEnabled returns whether the engine is active.
|
||||
func (e *Engine) IsEnabled() bool {
|
||||
return e.config.Enabled
|
||||
}
|
||||
|
||||
// BuildContext scores and selects relevant facts for the given tool call,
|
||||
// returning a ContextFrame ready for formatting and injection.
|
||||
func (e *Engine) BuildContext(toolName string, args map[string]interface{}) *ctxdomain.ContextFrame {
|
||||
frame := ctxdomain.NewContextFrame(toolName, e.config.TokenBudget)
|
||||
|
||||
if !e.config.Enabled {
|
||||
return frame
|
||||
}
|
||||
|
||||
// Extract keywords from all string arguments
|
||||
keywords := e.extractKeywordsFromArgs(args)
|
||||
|
||||
// Get candidate facts from provider
|
||||
facts, err := e.provider.GetRelevantFacts(args)
|
||||
if err != nil || len(facts) == 0 {
|
||||
return frame
|
||||
}
|
||||
|
||||
// Get current access counts snapshot
|
||||
e.mu.RLock()
|
||||
countsCopy := make(map[string]int, len(e.accessCounts))
|
||||
for k, v := range e.accessCounts {
|
||||
countsCopy[k] = v
|
||||
}
|
||||
e.mu.RUnlock()
|
||||
|
||||
// Score and rank facts
|
||||
ranked := e.scorer.RankFacts(facts, keywords, countsCopy)
|
||||
|
||||
// Fill frame within token budget and max facts
|
||||
added := 0
|
||||
for _, sf := range ranked {
|
||||
if added >= e.config.MaxFacts {
|
||||
break
|
||||
}
|
||||
if frame.AddFact(sf) {
|
||||
added++
|
||||
// Record access for reinforcement
|
||||
e.recordAccessInternal(sf.Fact.ID)
|
||||
e.provider.RecordAccess(sf.Fact.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return frame
|
||||
}
|
||||
|
||||
// GetAccessCount returns the internal access count for a fact.
|
||||
func (e *Engine) GetAccessCount(factID string) int {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
return e.accessCounts[factID]
|
||||
}
|
||||
|
||||
// recordAccessInternal increments the in-memory access counter.
|
||||
func (e *Engine) recordAccessInternal(factID string) {
|
||||
e.mu.Lock()
|
||||
e.accessCounts[factID]++
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// Middleware returns a ToolHandlerMiddleware that wraps every tool handler
|
||||
// to inject relevant memory context into the response and optionally
|
||||
// record tool calls to the interaction log for crash-safe memory.
|
||||
func (e *Engine) Middleware() server.ToolHandlerMiddleware {
|
||||
return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Fire-and-forget: record this tool call in the interaction log.
|
||||
// This runs BEFORE the handler so the record is persisted even if
|
||||
// the process is killed mid-handler.
|
||||
if e.logger != nil {
|
||||
if logErr := e.logger.Record(ctx, req.Params.Name, req.GetArguments()); logErr != nil {
|
||||
log.Printf("contextengine: interaction log error: %v", logErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Call the original handler
|
||||
result, err := next(ctx, req)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Don't inject on nil result, error results, or empty content
|
||||
if result == nil || result.IsError || len(result.Content) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Don't inject if engine is disabled
|
||||
if !e.IsEnabled() {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Don't inject for tools in the skip list
|
||||
if e.config.ShouldSkip(req.Params.Name) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Build context frame
|
||||
frame := e.BuildContext(req.Params.Name, req.GetArguments())
|
||||
contextText := frame.Format()
|
||||
if contextText == "" {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Append context to the last text content block
|
||||
e.appendContextToResult(result, contextText)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// appendContextToResult appends the context text to the last TextContent in the result.
|
||||
func (e *Engine) appendContextToResult(result *mcp.CallToolResult, contextText string) {
|
||||
for i := len(result.Content) - 1; i >= 0; i-- {
|
||||
if tc, ok := result.Content[i].(mcp.TextContent); ok {
|
||||
tc.Text += contextText
|
||||
result.Content[i] = tc
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// No text content found — add a new one
|
||||
result.Content = append(result.Content, mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: contextText,
|
||||
})
|
||||
}
|
||||
|
||||
// extractKeywordsFromArgs extracts keywords from all string values in the arguments map.
|
||||
func (e *Engine) extractKeywordsFromArgs(args map[string]interface{}) []string {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var allText strings.Builder
|
||||
for _, v := range args {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
allText.WriteString(val)
|
||||
allText.WriteString(" ")
|
||||
}
|
||||
}
|
||||
|
||||
return ctxdomain.ExtractKeywords(allText.String())
|
||||
}
|
||||
708
internal/application/contextengine/engine_test.go
Normal file
708
internal/application/contextengine/engine_test.go
Normal file
|
|
@ -0,0 +1,708 @@
|
|||
package contextengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
|
||||
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Mock FactProvider ---
|
||||
|
||||
type mockProvider struct {
|
||||
mu sync.Mutex
|
||||
facts []*memory.Fact
|
||||
l0 []*memory.Fact
|
||||
// tracks RecordAccess calls
|
||||
accessed map[string]int
|
||||
}
|
||||
|
||||
func newMockProvider(facts ...*memory.Fact) *mockProvider {
|
||||
l0 := make([]*memory.Fact, 0)
|
||||
for _, f := range facts {
|
||||
if f.Level == memory.LevelProject {
|
||||
l0 = append(l0, f)
|
||||
}
|
||||
}
|
||||
return &mockProvider{
|
||||
facts: facts,
|
||||
l0: l0,
|
||||
accessed: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockProvider) GetRelevantFacts(_ map[string]interface{}) ([]*memory.Fact, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.facts, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) GetL0Facts() ([]*memory.Fact, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.l0, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) RecordAccess(factID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.accessed[factID]++
|
||||
}
|
||||
|
||||
// --- Engine tests ---
|
||||
|
||||
func TestNewEngine(t *testing.T) {
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
provider := newMockProvider()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
require.NotNil(t, engine)
|
||||
assert.True(t, engine.IsEnabled())
|
||||
}
|
||||
|
||||
func TestNewEngine_Disabled(t *testing.T) {
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
cfg.Enabled = false
|
||||
engine := New(cfg, newMockProvider())
|
||||
|
||||
assert.False(t, engine.IsEnabled())
|
||||
}
|
||||
|
||||
func TestEngine_BuildContext_NoFacts(t *testing.T) {
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
provider := newMockProvider()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
frame := engine.BuildContext("test_tool", map[string]interface{}{
|
||||
"content": "hello world",
|
||||
})
|
||||
|
||||
assert.NotNil(t, frame)
|
||||
assert.Empty(t, frame.Facts)
|
||||
assert.Equal(t, "", frame.Format())
|
||||
}
|
||||
|
||||
func TestEngine_BuildContext_WithFacts(t *testing.T) {
|
||||
fact1 := memory.NewFact("Architecture uses clean layers", memory.LevelProject, "arch", "")
|
||||
fact2 := memory.NewFact("TDD is mandatory for all code", memory.LevelProject, "process", "")
|
||||
fact3 := memory.NewFact("Random snippet from old session", memory.LevelSnippet, "misc", "")
|
||||
fact3.CreatedAt = time.Now().Add(-90 * 24 * time.Hour) // very old
|
||||
|
||||
provider := newMockProvider(fact1, fact2, fact3)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
cfg.TokenBudget = 500
|
||||
engine := New(cfg, provider)
|
||||
|
||||
frame := engine.BuildContext("add_fact", map[string]interface{}{
|
||||
"content": "architecture decision",
|
||||
})
|
||||
|
||||
require.NotNil(t, frame)
|
||||
assert.NotEmpty(t, frame.Facts)
|
||||
// L0 facts should be included and ranked higher
|
||||
assert.Equal(t, "add_fact", frame.ToolName)
|
||||
|
||||
formatted := frame.Format()
|
||||
assert.Contains(t, formatted, "[MEMORY CONTEXT]")
|
||||
assert.Contains(t, formatted, "[/MEMORY CONTEXT]")
|
||||
}
|
||||
|
||||
func TestEngine_BuildContext_RespectsTokenBudget(t *testing.T) {
|
||||
// Create many facts that exceed token budget
|
||||
facts := make([]*memory.Fact, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
facts[i] = memory.NewFact(
|
||||
fmt.Sprintf("Fact number %d with enough content to consume tokens in the budget allocation system", i),
|
||||
memory.LevelProject, "arch", "",
|
||||
)
|
||||
}
|
||||
|
||||
provider := newMockProvider(facts...)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
cfg.TokenBudget = 100 // tight budget
|
||||
cfg.MaxFacts = 50
|
||||
engine := New(cfg, provider)
|
||||
|
||||
frame := engine.BuildContext("test", map[string]interface{}{"query": "test"})
|
||||
assert.LessOrEqual(t, frame.TokensUsed, cfg.TokenBudget)
|
||||
}
|
||||
|
||||
func TestEngine_BuildContext_RespectsMaxFacts(t *testing.T) {
|
||||
facts := make([]*memory.Fact, 20)
|
||||
for i := 0; i < 20; i++ {
|
||||
facts[i] = memory.NewFact(fmt.Sprintf("Fact %d", i), memory.LevelProject, "arch", "")
|
||||
}
|
||||
|
||||
provider := newMockProvider(facts...)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
cfg.MaxFacts = 5
|
||||
cfg.TokenBudget = 10000 // large budget so max_facts is the limiter
|
||||
engine := New(cfg, provider)
|
||||
|
||||
frame := engine.BuildContext("test", map[string]interface{}{"query": "fact"})
|
||||
assert.LessOrEqual(t, len(frame.Facts), cfg.MaxFacts)
|
||||
}
|
||||
|
||||
func TestEngine_BuildContext_DisabledReturnsEmpty(t *testing.T) {
|
||||
fact := memory.NewFact("test", memory.LevelProject, "arch", "")
|
||||
provider := newMockProvider(fact)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
cfg.Enabled = false
|
||||
engine := New(cfg, provider)
|
||||
|
||||
frame := engine.BuildContext("test", map[string]interface{}{"content": "test"})
|
||||
assert.Empty(t, frame.Facts)
|
||||
assert.Equal(t, "", frame.Format())
|
||||
}
|
||||
|
||||
func TestEngine_RecordsAccess(t *testing.T) {
|
||||
fact1 := memory.NewFact("Architecture pattern", memory.LevelProject, "arch", "")
|
||||
provider := newMockProvider(fact1)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
frame := engine.BuildContext("test", map[string]interface{}{"content": "architecture"})
|
||||
require.NotEmpty(t, frame.Facts)
|
||||
|
||||
// Check that RecordAccess was called on the provider
|
||||
provider.mu.Lock()
|
||||
count := provider.accessed[fact1.ID]
|
||||
provider.mu.Unlock()
|
||||
assert.Greater(t, count, 0, "RecordAccess should be called for injected facts")
|
||||
}
|
||||
|
||||
func TestEngine_AccessCountTracking(t *testing.T) {
|
||||
fact := memory.NewFact("Architecture decision", memory.LevelProject, "arch", "")
|
||||
provider := newMockProvider(fact)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
// Build context 3 times
|
||||
for i := 0; i < 3; i++ {
|
||||
engine.BuildContext("test", map[string]interface{}{"content": "architecture"})
|
||||
}
|
||||
|
||||
// Internal access count should be tracked
|
||||
count := engine.GetAccessCount(fact.ID)
|
||||
assert.Equal(t, 3, count)
|
||||
}
|
||||
|
||||
func TestEngine_AccessCountInfluencesRanking(t *testing.T) {
|
||||
// Two similar facts but one has been accessed more
|
||||
fact1 := memory.NewFact("Architecture pattern A", memory.LevelDomain, "arch", "")
|
||||
fact2 := memory.NewFact("Architecture pattern B", memory.LevelDomain, "arch", "")
|
||||
|
||||
provider := newMockProvider(fact1, fact2)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
cfg.FrequencyWeight = 0.9 // heavily weight frequency
|
||||
cfg.KeywordWeight = 0.01
|
||||
cfg.RecencyWeight = 0.01
|
||||
cfg.LevelWeight = 0.01
|
||||
engine := New(cfg, provider)
|
||||
|
||||
// Simulate fact1 being accessed many times
|
||||
for i := 0; i < 20; i++ {
|
||||
engine.recordAccessInternal(fact1.ID)
|
||||
}
|
||||
|
||||
frame := engine.BuildContext("test", map[string]interface{}{"content": "architecture pattern"})
|
||||
require.GreaterOrEqual(t, len(frame.Facts), 2)
|
||||
// fact1 should rank higher due to frequency
|
||||
assert.Equal(t, fact1.ID, frame.Facts[0].Fact.ID)
|
||||
}
|
||||
|
||||
// --- Middleware tests ---
|
||||
|
||||
func TestMiddleware_InjectsContext(t *testing.T) {
|
||||
fact := memory.NewFact("Always remember: TDD first", memory.LevelProject, "process", "")
|
||||
provider := newMockProvider(fact)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
// Create a simple handler
|
||||
handler := func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "Original result"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Wrap with middleware
|
||||
wrapped := engine.Middleware()(handler)
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = "test_tool"
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"content": "TDD process",
|
||||
}
|
||||
|
||||
result, err := wrapped(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.Content, 1)
|
||||
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
assert.Contains(t, text, "Original result")
|
||||
assert.Contains(t, text, "[MEMORY CONTEXT]")
|
||||
assert.Contains(t, text, "TDD first")
|
||||
}
|
||||
|
||||
func TestMiddleware_DisabledPassesThrough(t *testing.T) {
|
||||
fact := memory.NewFact("should not appear", memory.LevelProject, "test", "")
|
||||
provider := newMockProvider(fact)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
cfg.Enabled = false
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "Original only"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
assert.Equal(t, "Original only", text)
|
||||
assert.NotContains(t, text, "[MEMORY CONTEXT]")
|
||||
}
|
||||
|
||||
func TestMiddleware_HandlerErrorPassedThrough(t *testing.T) {
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return nil, fmt.Errorf("handler error")
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestMiddleware_ErrorResult_NoInjection(t *testing.T) {
|
||||
fact := memory.NewFact("should not appear on errors", memory.LevelProject, "test", "")
|
||||
provider := newMockProvider(fact)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "Error: something failed"},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
assert.NotContains(t, text, "[MEMORY CONTEXT]", "should not inject context on error results")
|
||||
}
|
||||
|
||||
func TestMiddleware_EmptyContentSlice(t *testing.T) {
|
||||
provider := newMockProvider(memory.NewFact("test", memory.LevelProject, "a", ""))
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
|
||||
require.NoError(t, err)
|
||||
// Should handle empty content gracefully
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestMiddleware_NilResult(t *testing.T) {
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
result, err := wrapped(context.Background(), mcp.CallToolRequest{})
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestMiddleware_SkipListTools(t *testing.T) {
|
||||
fact := memory.NewFact("Should not appear for skipped tools", memory.LevelProject, "test", "")
|
||||
provider := newMockProvider(fact)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "Facts result"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
|
||||
// Tools in default skip list should NOT get context injected
|
||||
skipTools := []string{"search_facts", "get_fact", "get_l0_facts", "health", "version", "dashboard"}
|
||||
for _, tool := range skipTools {
|
||||
t.Run(tool, func(t *testing.T) {
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = tool
|
||||
req.Params.Arguments = map[string]interface{}{"query": "test"}
|
||||
|
||||
result, err := wrapped(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
assert.NotContains(t, text, "[MEMORY CONTEXT]",
|
||||
"tool %s is in skip list, should not get context injected", tool)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_NonSkipToolGetsContext(t *testing.T) {
|
||||
fact := memory.NewFact("Important architecture fact", memory.LevelProject, "arch", "")
|
||||
provider := newMockProvider(fact)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "Tool result"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = "add_causal_node"
|
||||
req.Params.Arguments = map[string]interface{}{"content": "architecture decision"}
|
||||
|
||||
result, err := wrapped(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
assert.Contains(t, text, "[MEMORY CONTEXT]")
|
||||
}
|
||||
|
||||
// --- Concurrency test ---
|
||||
|
||||
func TestEngine_ConcurrentAccess(t *testing.T) {
|
||||
facts := make([]*memory.Fact, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
facts[i] = memory.NewFact(fmt.Sprintf("Concurrent fact %d", i), memory.LevelProject, "arch", "")
|
||||
}
|
||||
|
||||
provider := newMockProvider(facts...)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
engine.BuildContext("tool", map[string]interface{}{
|
||||
"content": fmt.Sprintf("query %d", n),
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Just verify no panics or races (run with -race)
|
||||
for _, f := range facts {
|
||||
count := engine.GetAccessCount(f.ID)
|
||||
assert.GreaterOrEqual(t, count, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Benchmark ---
|
||||
|
||||
func BenchmarkEngine_BuildContext(b *testing.B) {
|
||||
facts := make([]*memory.Fact, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
facts[i] = memory.NewFact(
|
||||
"Architecture uses clean layers with dependency injection for modularity",
|
||||
memory.HierLevel(i%4), "arch", "core",
|
||||
)
|
||||
}
|
||||
|
||||
provider := newMockProvider(facts...)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
args := map[string]interface{}{
|
||||
"content": "architecture clean layers dependency",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
engine.BuildContext("test_tool", args)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMiddleware(b *testing.B) {
|
||||
facts := make([]*memory.Fact, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
facts[i] = memory.NewFact("test fact content", memory.LevelProject, "arch", "")
|
||||
}
|
||||
|
||||
provider := newMockProvider(facts...)
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "result"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{"content": "test"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = wrapped(context.Background(), req)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Mock InteractionLogger ---
|
||||
|
||||
type mockInteractionLogger struct {
|
||||
mu sync.Mutex
|
||||
entries []logEntry
|
||||
failErr error // if set, Record returns this error
|
||||
}
|
||||
|
||||
type logEntry struct {
|
||||
toolName string
|
||||
args map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockInteractionLogger) Record(_ context.Context, toolName string, args map[string]interface{}) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.failErr != nil {
|
||||
return m.failErr
|
||||
}
|
||||
m.entries = append(m.entries, logEntry{toolName: toolName, args: args})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInteractionLogger) count() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return len(m.entries)
|
||||
}
|
||||
|
||||
func (m *mockInteractionLogger) lastToolName() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if len(m.entries) == 0 {
|
||||
return ""
|
||||
}
|
||||
return m.entries[len(m.entries)-1].toolName
|
||||
}
|
||||
|
||||
// --- Interaction Logger Tests ---
|
||||
|
||||
func TestMiddleware_InteractionLogger_RecordsToolCalls(t *testing.T) {
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
logger := &mockInteractionLogger{}
|
||||
engine.SetInteractionLogger(logger)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "ok"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = "add_fact"
|
||||
req.Params.Arguments = map[string]interface{}{"content": "test fact"}
|
||||
|
||||
_, err := wrapped(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, logger.count())
|
||||
assert.Equal(t, "add_fact", logger.lastToolName())
|
||||
}
|
||||
|
||||
func TestMiddleware_InteractionLogger_RecordsSkippedTools(t *testing.T) {
|
||||
// Even skip-list tools should be recorded in the interaction log
|
||||
// (skip-list only controls context injection, not logging)
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
logger := &mockInteractionLogger{}
|
||||
engine.SetInteractionLogger(logger)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "ok"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
|
||||
// "health" is in the skip list
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = "health"
|
||||
|
||||
_, err := wrapped(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, logger.count(), "skip-list tools should still be logged")
|
||||
assert.Equal(t, "health", logger.lastToolName())
|
||||
}
|
||||
|
||||
func TestMiddleware_InteractionLogger_ErrorDoesNotBreakHandler(t *testing.T) {
|
||||
// Logger errors must be swallowed — never break the tool call
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
logger := &mockInteractionLogger{failErr: fmt.Errorf("disk full")}
|
||||
engine.SetInteractionLogger(logger)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "handler succeeded"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = "add_fact"
|
||||
|
||||
result, err := wrapped(context.Background(), req)
|
||||
require.NoError(t, err, "logger error must not propagate")
|
||||
require.NotNil(t, result)
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
assert.Contains(t, text, "handler succeeded")
|
||||
}
|
||||
|
||||
func TestMiddleware_NoLogger_StillWorks(t *testing.T) {
|
||||
// Without a logger set, middleware should work normally
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
// engine.logger is nil — no SetInteractionLogger call
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "no logger ok"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = "add_fact"
|
||||
|
||||
result, err := wrapped(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestMiddleware_InteractionLogger_MultipleToolCalls(t *testing.T) {
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
logger := &mockInteractionLogger{}
|
||||
engine.SetInteractionLogger(logger)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "ok"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
|
||||
toolNames := []string{"add_fact", "search_facts", "health", "add_causal_node", "version"}
|
||||
for _, name := range toolNames {
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = name
|
||||
_, _ = wrapped(context.Background(), req)
|
||||
}
|
||||
|
||||
assert.Equal(t, 5, logger.count(), "all 5 tool calls should be logged")
|
||||
}
|
||||
|
||||
func TestMiddleware_InteractionLogger_ConcurrentCalls(t *testing.T) {
|
||||
provider := newMockProvider()
|
||||
cfg := ctxdomain.DefaultEngineConfig()
|
||||
engine := New(cfg, provider)
|
||||
|
||||
logger := &mockInteractionLogger{}
|
||||
engine.SetInteractionLogger(logger)
|
||||
|
||||
handler := func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "ok"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
wrapped := engine.Middleware()(handler)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Name = fmt.Sprintf("tool_%d", n)
|
||||
_, _ = wrapped(context.Background(), req)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 20, logger.count(), "all 20 concurrent calls should be logged")
|
||||
}
|
||||
245
internal/application/contextengine/processor.go
Normal file
245
internal/application/contextengine/processor.go
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
// Package contextengine — processor.go
|
||||
// Processes unprocessed interaction log entries into session summary facts.
|
||||
// This closes the memory loop: tool calls → interaction log → summary facts → boot instructions.
|
||||
package contextengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
|
||||
)
|
||||
|
||||
// InteractionProcessor processes unprocessed interaction log entries
|
||||
// and creates session summary facts from them.
|
||||
type InteractionProcessor struct {
|
||||
repo *sqlite.InteractionLogRepo
|
||||
factStore memory.FactStore
|
||||
}
|
||||
|
||||
// NewInteractionProcessor creates a new processor.
|
||||
func NewInteractionProcessor(repo *sqlite.InteractionLogRepo, store memory.FactStore) *InteractionProcessor {
|
||||
return &InteractionProcessor{repo: repo, factStore: store}
|
||||
}
|
||||
|
||||
// ProcessStartup processes unprocessed entries from a previous (possibly crashed) session.
|
||||
// It creates an L1 "session summary" fact and marks all entries as processed.
|
||||
// Returns the summary text (empty if nothing to process).
|
||||
func (p *InteractionProcessor) ProcessStartup(ctx context.Context) (string, error) {
|
||||
entries, err := p.repo.GetUnprocessed(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get unprocessed: %w", err)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
summary := buildSessionSummary(entries, "previous session (recovered)")
|
||||
if summary == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Save as L1 fact (domain-level, not project-level)
|
||||
fact := memory.NewFact(summary, memory.LevelDomain, "session-history", "interaction-processor")
|
||||
fact.Source = "auto:interaction-processor"
|
||||
if err := p.factStore.Add(ctx, fact); err != nil {
|
||||
return "", fmt.Errorf("save session summary fact: %w", err)
|
||||
}
|
||||
|
||||
// Mark all as processed
|
||||
ids := make([]int64, len(entries))
|
||||
for i, e := range entries {
|
||||
ids[i] = e.ID
|
||||
}
|
||||
if err := p.repo.MarkProcessed(ctx, ids); err != nil {
|
||||
return "", fmt.Errorf("mark processed: %w", err)
|
||||
}
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// ProcessShutdown processes entries from the current session at graceful shutdown.
|
||||
// Similar to ProcessStartup but labels differently.
|
||||
func (p *InteractionProcessor) ProcessShutdown(ctx context.Context) (string, error) {
|
||||
entries, err := p.repo.GetUnprocessed(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get unprocessed: %w", err)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
summary := buildSessionSummary(entries, "session ending "+time.Now().Format("2006-01-02 15:04"))
|
||||
if summary == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
fact := memory.NewFact(summary, memory.LevelDomain, "session-history", "interaction-processor")
|
||||
fact.Source = "auto:session-shutdown"
|
||||
if err := p.factStore.Add(ctx, fact); err != nil {
|
||||
return "", fmt.Errorf("save session summary fact: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]int64, len(entries))
|
||||
for i, e := range entries {
|
||||
ids[i] = e.ID
|
||||
}
|
||||
if err := p.repo.MarkProcessed(ctx, ids); err != nil {
|
||||
return "", fmt.Errorf("mark processed: %w", err)
|
||||
}
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// buildSessionSummary creates a compact text summary from interaction log entries.
|
||||
func buildSessionSummary(entries []sqlite.InteractionEntry, label string) string {
|
||||
if len(entries) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Count tool calls
|
||||
toolCounts := make(map[string]int)
|
||||
for _, e := range entries {
|
||||
toolCounts[e.ToolName]++
|
||||
}
|
||||
|
||||
// Sort by count descending
|
||||
type toolStat struct {
|
||||
name string
|
||||
count int
|
||||
}
|
||||
stats := make([]toolStat, 0, len(toolCounts))
|
||||
for name, count := range toolCounts {
|
||||
stats = append(stats, toolStat{name, count})
|
||||
}
|
||||
sort.Slice(stats, func(i, j int) bool { return stats[i].count > stats[j].count })
|
||||
|
||||
// Extract topics from args (unique string values)
|
||||
topics := extractTopicsFromEntries(entries)
|
||||
|
||||
// Time range
|
||||
var earliest, latest time.Time
|
||||
for _, e := range entries {
|
||||
if earliest.IsZero() || e.Timestamp.Before(earliest) {
|
||||
earliest = e.Timestamp
|
||||
}
|
||||
if latest.IsZero() || e.Timestamp.After(latest) {
|
||||
latest = e.Timestamp
|
||||
}
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("Session summary (%s): %d tool calls", label, len(entries)))
|
||||
if !earliest.IsZero() {
|
||||
duration := latest.Sub(earliest)
|
||||
if duration > 0 {
|
||||
b.WriteString(fmt.Sprintf(" over %s", formatDuration(duration)))
|
||||
}
|
||||
}
|
||||
b.WriteString(". ")
|
||||
|
||||
// Top tools used
|
||||
b.WriteString("Tools used: ")
|
||||
for i, ts := range stats {
|
||||
if i >= 8 {
|
||||
b.WriteString(fmt.Sprintf(" +%d more", len(stats)-8))
|
||||
break
|
||||
}
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("%s(%d)", ts.name, ts.count))
|
||||
}
|
||||
b.WriteString(". ")
|
||||
|
||||
// Topics
|
||||
if len(topics) > 0 {
|
||||
b.WriteString("Topics: ")
|
||||
limit := 10
|
||||
if len(topics) < limit {
|
||||
limit = len(topics)
|
||||
}
|
||||
b.WriteString(strings.Join(topics[:limit], ", "))
|
||||
if len(topics) > limit {
|
||||
b.WriteString(fmt.Sprintf(" +%d more", len(topics)-limit))
|
||||
}
|
||||
b.WriteString(".")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// extractTopicsFromEntries pulls unique meaningful strings from tool arguments.
|
||||
func extractTopicsFromEntries(entries []sqlite.InteractionEntry) []string {
|
||||
seen := make(map[string]bool)
|
||||
var topics []string
|
||||
|
||||
for _, e := range entries {
|
||||
if e.ArgsJSON == "" {
|
||||
continue
|
||||
}
|
||||
// Simple extraction: find quoted strings in JSON args
|
||||
// ArgsJSON looks like {"query":"architecture","content":"some fact"}
|
||||
parts := strings.Split(e.ArgsJSON, "\"")
|
||||
for i := 3; i < len(parts); i += 4 {
|
||||
// Values are at odd positions after the key
|
||||
val := parts[i]
|
||||
if len(val) < 3 || len(val) > 100 {
|
||||
continue
|
||||
}
|
||||
// Skip common non-topic values
|
||||
lower := strings.ToLower(val)
|
||||
if lower == "true" || lower == "false" || lower == "null" || lower == "" {
|
||||
continue
|
||||
}
|
||||
if !seen[lower] {
|
||||
seen[lower] = true
|
||||
topics = append(topics, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return topics
|
||||
}
|
||||
|
||||
// formatDuration formats a duration into a human-readable string.
|
||||
func formatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%ds", int(d.Seconds()))
|
||||
}
|
||||
if d < time.Hour {
|
||||
return fmt.Sprintf("%dm", int(d.Minutes()))
|
||||
}
|
||||
return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60)
|
||||
}
|
||||
|
||||
// GetLastSessionSummary searches the fact store for the most recent session summary.
|
||||
func GetLastSessionSummary(ctx context.Context, store memory.FactStore) string {
|
||||
facts, err := store.Search(ctx, "Session summary", 5)
|
||||
if err != nil || len(facts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the most recent one from session-history domain
|
||||
var best *memory.Fact
|
||||
for _, f := range facts {
|
||||
if f.Domain != "session-history" {
|
||||
continue
|
||||
}
|
||||
if f.IsStale || f.IsArchived {
|
||||
continue
|
||||
}
|
||||
if best == nil || f.CreatedAt.After(best.CreatedAt) {
|
||||
best = f
|
||||
}
|
||||
}
|
||||
|
||||
if best == nil {
|
||||
return ""
|
||||
}
|
||||
return best.Content
|
||||
}
|
||||
251
internal/application/contextengine/processor_test.go
Normal file
251
internal/application/contextengine/processor_test.go
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
package contextengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock FactStore for processor tests ---
|
||||
|
||||
type procMockFactStore struct {
|
||||
facts []*memory.Fact
|
||||
}
|
||||
|
||||
func (m *procMockFactStore) Add(_ context.Context, f *memory.Fact) error {
|
||||
m.facts = append(m.facts, f)
|
||||
return nil
|
||||
}
|
||||
func (m *procMockFactStore) Get(_ context.Context, id string) (*memory.Fact, error) {
|
||||
for _, f := range m.facts {
|
||||
if f.ID == id {
|
||||
return f, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *procMockFactStore) Update(_ context.Context, _ *memory.Fact) error { return nil }
|
||||
func (m *procMockFactStore) Delete(_ context.Context, _ string) error { return nil }
|
||||
func (m *procMockFactStore) ListByDomain(_ context.Context, _ string, _ bool) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *procMockFactStore) ListByLevel(_ context.Context, _ memory.HierLevel) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *procMockFactStore) ListDomains(_ context.Context) ([]string, error) { return nil, nil }
|
||||
func (m *procMockFactStore) GetStale(_ context.Context, _ bool) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *procMockFactStore) Search(_ context.Context, query string, limit int) ([]*memory.Fact, error) {
|
||||
var results []*memory.Fact
|
||||
for _, f := range m.facts {
|
||||
if len(results) >= limit {
|
||||
break
|
||||
}
|
||||
if contains(f.Content, query) {
|
||||
results = append(results, f)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
func (m *procMockFactStore) GetExpired(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
|
||||
func (m *procMockFactStore) RefreshTTL(_ context.Context, _ string) error { return nil }
|
||||
func (m *procMockFactStore) TouchFact(_ context.Context, _ string) error { return nil }
|
||||
func (m *procMockFactStore) GetColdFacts(_ context.Context, _ int) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *procMockFactStore) CompressFacts(_ context.Context, _ []string, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (m *procMockFactStore) Stats(_ context.Context) (*memory.FactStoreStats, error) {
|
||||
return &memory.FactStoreStats{}, nil
|
||||
}
|
||||
func (m *procMockFactStore) ListGenes(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
|
||||
|
||||
func contains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(sub) == 0 ||
|
||||
(len(s) > 0 && len(sub) > 0 && containsStr(s, sub)))
|
||||
}
|
||||
func containsStr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestInteractionProcessor_ProcessStartup_NoEntries(t *testing.T) {
|
||||
db, err := sqlite.OpenMemory()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo, err := sqlite.NewInteractionLogRepo(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
store := &procMockFactStore{}
|
||||
proc := NewInteractionProcessor(repo, store)
|
||||
|
||||
summary, err := proc.ProcessStartup(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, summary)
|
||||
assert.Empty(t, store.facts, "no facts should be created")
|
||||
}
|
||||
|
||||
func TestInteractionProcessor_ProcessStartup_CreatesFactAndMarksProcessed(t *testing.T) {
|
||||
db, err := sqlite.OpenMemory()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo, err := sqlite.NewInteractionLogRepo(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert some tool calls
|
||||
require.NoError(t, repo.Record(ctx, "add_fact", map[string]interface{}{"content": "test fact about architecture"}))
|
||||
require.NoError(t, repo.Record(ctx, "search_facts", map[string]interface{}{"query": "security"}))
|
||||
require.NoError(t, repo.Record(ctx, "health", nil))
|
||||
require.NoError(t, repo.Record(ctx, "add_fact", map[string]interface{}{"content": "another fact"}))
|
||||
require.NoError(t, repo.Record(ctx, "dashboard", nil))
|
||||
|
||||
store := &procMockFactStore{}
|
||||
proc := NewInteractionProcessor(repo, store)
|
||||
|
||||
summary, err := proc.ProcessStartup(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, summary)
|
||||
assert.Contains(t, summary, "Session summary")
|
||||
assert.Contains(t, summary, "5 tool calls")
|
||||
assert.Contains(t, summary, "add_fact(2)")
|
||||
|
||||
// Fact should be saved
|
||||
require.Len(t, store.facts, 1)
|
||||
assert.Equal(t, memory.LevelDomain, store.facts[0].Level)
|
||||
assert.Equal(t, "session-history", store.facts[0].Domain)
|
||||
assert.Equal(t, "auto:interaction-processor", store.facts[0].Source)
|
||||
|
||||
// All entries should be marked processed
|
||||
_, unprocessed, err := repo.Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, unprocessed)
|
||||
}
|
||||
|
||||
func TestInteractionProcessor_ProcessShutdown(t *testing.T) {
|
||||
db, err := sqlite.OpenMemory()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo, err := sqlite.NewInteractionLogRepo(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
require.NoError(t, repo.Record(ctx, "version", nil))
|
||||
require.NoError(t, repo.Record(ctx, "search_facts", map[string]interface{}{"query": "gomcp"}))
|
||||
|
||||
store := &procMockFactStore{}
|
||||
proc := NewInteractionProcessor(repo, store)
|
||||
|
||||
summary, err := proc.ProcessShutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, summary, "session ending")
|
||||
assert.Contains(t, summary, "2 tool calls")
|
||||
|
||||
require.Len(t, store.facts, 1)
|
||||
assert.Equal(t, "auto:session-shutdown", store.facts[0].Source)
|
||||
}
|
||||
|
||||
func TestBuildSessionSummary_ToolCounts(t *testing.T) {
|
||||
now := time.Now()
|
||||
entries := []sqlite.InteractionEntry{
|
||||
{ID: 1, ToolName: "add_fact", Timestamp: now},
|
||||
{ID: 2, ToolName: "add_fact", Timestamp: now},
|
||||
{ID: 3, ToolName: "add_fact", Timestamp: now},
|
||||
{ID: 4, ToolName: "search_facts", Timestamp: now},
|
||||
{ID: 5, ToolName: "health", Timestamp: now},
|
||||
}
|
||||
|
||||
summary := buildSessionSummary(entries, "test")
|
||||
assert.Contains(t, summary, "5 tool calls")
|
||||
assert.Contains(t, summary, "add_fact(3)")
|
||||
assert.Contains(t, summary, "search_facts(1)")
|
||||
assert.Contains(t, summary, "health(1)")
|
||||
}
|
||||
|
||||
func TestBuildSessionSummary_Duration(t *testing.T) {
|
||||
now := time.Now()
|
||||
entries := []sqlite.InteractionEntry{
|
||||
{ID: 1, ToolName: "a", Timestamp: now.Add(-30 * time.Minute)},
|
||||
{ID: 2, ToolName: "b", Timestamp: now},
|
||||
}
|
||||
|
||||
summary := buildSessionSummary(entries, "test")
|
||||
assert.Contains(t, summary, "30m")
|
||||
}
|
||||
|
||||
func TestBuildSessionSummary_Empty(t *testing.T) {
|
||||
summary := buildSessionSummary(nil, "test")
|
||||
assert.Empty(t, summary)
|
||||
}
|
||||
|
||||
func TestBuildSessionSummary_Topics(t *testing.T) {
|
||||
now := time.Now()
|
||||
entries := []sqlite.InteractionEntry{
|
||||
{ID: 1, ToolName: "search_facts", ArgsJSON: `{"query":"architecture"}`, Timestamp: now},
|
||||
{ID: 2, ToolName: "add_fact", ArgsJSON: `{"content":"security review"}`, Timestamp: now},
|
||||
}
|
||||
|
||||
summary := buildSessionSummary(entries, "test")
|
||||
assert.Contains(t, summary, "Topics:")
|
||||
assert.Contains(t, summary, "architecture")
|
||||
}
|
||||
|
||||
func TestGetLastSessionSummary_Found(t *testing.T) {
|
||||
store := &procMockFactStore{}
|
||||
f := memory.NewFact("Session summary (test): 5 tool calls", memory.LevelDomain, "session-history", "")
|
||||
f.Source = "auto:session-shutdown"
|
||||
store.facts = append(store.facts, f)
|
||||
|
||||
result := GetLastSessionSummary(context.Background(), store)
|
||||
assert.Contains(t, result, "Session summary")
|
||||
}
|
||||
|
||||
func TestGetLastSessionSummary_NotFound(t *testing.T) {
|
||||
store := &procMockFactStore{}
|
||||
result := GetLastSessionSummary(context.Background(), store)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestGetLastSessionSummary_SkipsStale(t *testing.T) {
|
||||
store := &procMockFactStore{}
|
||||
f := memory.NewFact("Session summary (test): old", memory.LevelDomain, "session-history", "")
|
||||
f.IsStale = true
|
||||
store.facts = append(store.facts, f)
|
||||
|
||||
result := GetLastSessionSummary(context.Background(), store)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestFormatDuration(t *testing.T) {
|
||||
assert.Equal(t, "30s", formatDuration(30*time.Second))
|
||||
assert.Equal(t, "5m", formatDuration(5*time.Minute))
|
||||
assert.Equal(t, "2h15m", formatDuration(2*time.Hour+15*time.Minute))
|
||||
}
|
||||
|
||||
func TestExtractTopicsFromEntries(t *testing.T) {
|
||||
entries := []sqlite.InteractionEntry{
|
||||
{ArgsJSON: `{"query":"architecture"}`},
|
||||
{ArgsJSON: `{"content":"security review"}`},
|
||||
{ArgsJSON: ``}, // empty
|
||||
}
|
||||
|
||||
topics := extractTopicsFromEntries(entries)
|
||||
assert.NotEmpty(t, topics)
|
||||
}
|
||||
117
internal/application/contextengine/provider.go
Normal file
117
internal/application/contextengine/provider.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package contextengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
|
||||
ctxdomain "github.com/sentinel-community/gomcp/internal/domain/context"
|
||||
)
|
||||
|
||||
// StoreFactProvider adapts FactStore + HotCache to the FactProvider interface,
|
||||
// bridging infrastructure storage with the context engine domain.
|
||||
type StoreFactProvider struct {
|
||||
store memory.FactStore
|
||||
cache memory.HotCache
|
||||
|
||||
mu sync.Mutex
|
||||
accessCounts map[string]int
|
||||
}
|
||||
|
||||
// NewStoreFactProvider creates a FactProvider backed by FactStore and optional HotCache.
|
||||
func NewStoreFactProvider(store memory.FactStore, cache memory.HotCache) *StoreFactProvider {
|
||||
return &StoreFactProvider{
|
||||
store: store,
|
||||
cache: cache,
|
||||
accessCounts: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// Verify interface compliance at compile time.
|
||||
var _ ctxdomain.FactProvider = (*StoreFactProvider)(nil)
|
||||
|
||||
// GetRelevantFacts returns candidate facts for context injection.
|
||||
// Uses keyword search from tool arguments + L0 facts as candidates.
|
||||
func (p *StoreFactProvider) GetRelevantFacts(args map[string]interface{}) ([]*memory.Fact, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Always include L0 facts
|
||||
l0Facts, err := p.GetL0Facts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract query text from arguments for search
|
||||
query := extractQueryFromArgs(args)
|
||||
if query == "" {
|
||||
return l0Facts, nil
|
||||
}
|
||||
|
||||
// Search for additional relevant facts
|
||||
searchResults, err := p.store.Search(ctx, query, 30)
|
||||
if err != nil {
|
||||
// Degrade gracefully — just return L0 facts
|
||||
return l0Facts, nil
|
||||
}
|
||||
|
||||
// Merge L0 + search results, deduplicating by ID
|
||||
seen := make(map[string]bool, len(l0Facts))
|
||||
merged := make([]*memory.Fact, 0, len(l0Facts)+len(searchResults))
|
||||
|
||||
for _, f := range l0Facts {
|
||||
seen[f.ID] = true
|
||||
merged = append(merged, f)
|
||||
}
|
||||
for _, f := range searchResults {
|
||||
if !seen[f.ID] {
|
||||
seen[f.ID] = true
|
||||
merged = append(merged, f)
|
||||
}
|
||||
}
|
||||
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// GetL0Facts returns all L0 (project-level) facts.
|
||||
// Uses HotCache if available, falls back to store.
|
||||
func (p *StoreFactProvider) GetL0Facts() ([]*memory.Fact, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if p.cache != nil {
|
||||
facts, err := p.cache.GetL0Facts(ctx)
|
||||
if err == nil && len(facts) > 0 {
|
||||
return facts, nil
|
||||
}
|
||||
}
|
||||
|
||||
return p.store.ListByLevel(ctx, memory.LevelProject)
|
||||
}
|
||||
|
||||
// RecordAccess increments the access counter for a fact.
|
||||
func (p *StoreFactProvider) RecordAccess(factID string) {
|
||||
p.mu.Lock()
|
||||
p.accessCounts[factID]++
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// extractQueryFromArgs builds a search query string from argument values.
|
||||
func extractQueryFromArgs(args map[string]interface{}) string {
|
||||
var parts []string
|
||||
for _, v := range args {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := ""
|
||||
for i, p := range parts {
|
||||
if i > 0 {
|
||||
result += " "
|
||||
}
|
||||
result += p
|
||||
}
|
||||
return result
|
||||
}
|
||||
278
internal/application/contextengine/provider_test.go
Normal file
278
internal/application/contextengine/provider_test.go
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
package contextengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Mock FactStore for provider tests ---
|
||||
|
||||
type mockFactStore struct {
|
||||
mu sync.Mutex
|
||||
facts map[string]*memory.Fact
|
||||
searchFacts []*memory.Fact
|
||||
searchErr error
|
||||
levelFacts map[memory.HierLevel][]*memory.Fact
|
||||
}
|
||||
|
||||
func newMockFactStore() *mockFactStore {
|
||||
return &mockFactStore{
|
||||
facts: make(map[string]*memory.Fact),
|
||||
levelFacts: make(map[memory.HierLevel][]*memory.Fact),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockFactStore) Add(_ context.Context, fact *memory.Fact) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.facts[fact.ID] = fact
|
||||
m.levelFacts[fact.Level] = append(m.levelFacts[fact.Level], fact)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) Get(_ context.Context, id string) (*memory.Fact, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
f, ok := m.facts[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not found")
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) Update(_ context.Context, fact *memory.Fact) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.facts[fact.ID] = fact
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) Delete(_ context.Context, id string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.facts, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) ListByDomain(_ context.Context, _ string, _ bool) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) ListByLevel(_ context.Context, level memory.HierLevel) ([]*memory.Fact, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.levelFacts[level], nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) ListDomains(_ context.Context) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) GetStale(_ context.Context, _ bool) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) Search(_ context.Context, _ string, _ int) ([]*memory.Fact, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.searchFacts, m.searchErr
|
||||
}
|
||||
|
||||
func (m *mockFactStore) GetExpired(_ context.Context) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) RefreshTTL(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) TouchFact(_ context.Context, _ string) error { return nil }
|
||||
func (m *mockFactStore) GetColdFacts(_ context.Context, _ int) ([]*memory.Fact, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockFactStore) CompressFacts(_ context.Context, _ []string, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) Stats(_ context.Context) (*memory.FactStoreStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockFactStore) ListGenes(_ context.Context) ([]*memory.Fact, error) { return nil, nil }
|
||||
|
||||
// --- Mock HotCache ---
|
||||
|
||||
type mockHotCache struct {
|
||||
l0Facts []*memory.Fact
|
||||
l0Err error
|
||||
}
|
||||
|
||||
func (m *mockHotCache) GetL0Facts(_ context.Context) ([]*memory.Fact, error) {
|
||||
return m.l0Facts, m.l0Err
|
||||
}
|
||||
|
||||
func (m *mockHotCache) InvalidateFact(_ context.Context, _ string) error { return nil }
|
||||
func (m *mockHotCache) WarmUp(_ context.Context, _ []*memory.Fact) error { return nil }
|
||||
func (m *mockHotCache) Close() error { return nil }
|
||||
|
||||
// --- StoreFactProvider tests ---
|
||||
|
||||
func TestNewStoreFactProvider(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
provider := NewStoreFactProvider(store, nil)
|
||||
require.NotNil(t, provider)
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_GetL0Facts_FromStore(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
f1 := memory.NewFact("L0 fact A", memory.LevelProject, "arch", "")
|
||||
f2 := memory.NewFact("L0 fact B", memory.LevelProject, "process", "")
|
||||
_ = store.Add(context.Background(), f1)
|
||||
_ = store.Add(context.Background(), f2)
|
||||
|
||||
provider := NewStoreFactProvider(store, nil) // no cache
|
||||
|
||||
facts, err := provider.GetL0Facts()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 2)
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_GetL0Facts_FromCache(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
cacheFact := memory.NewFact("Cached L0", memory.LevelProject, "arch", "")
|
||||
cache := &mockHotCache{l0Facts: []*memory.Fact{cacheFact}}
|
||||
|
||||
provider := NewStoreFactProvider(store, cache)
|
||||
|
||||
facts, err := provider.GetL0Facts()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 1)
|
||||
assert.Equal(t, "Cached L0", facts[0].Content)
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_GetL0Facts_CacheFallbackToStore(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
storeFact := memory.NewFact("Store L0", memory.LevelProject, "arch", "")
|
||||
_ = store.Add(context.Background(), storeFact)
|
||||
|
||||
cache := &mockHotCache{l0Facts: nil, l0Err: fmt.Errorf("cache miss")}
|
||||
provider := NewStoreFactProvider(store, cache)
|
||||
|
||||
facts, err := provider.GetL0Facts()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 1)
|
||||
assert.Equal(t, "Store L0", facts[0].Content)
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_GetRelevantFacts_NoQuery(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
l0 := memory.NewFact("L0 always included", memory.LevelProject, "arch", "")
|
||||
_ = store.Add(context.Background(), l0)
|
||||
|
||||
provider := NewStoreFactProvider(store, nil)
|
||||
|
||||
// No string args → no query → only L0
|
||||
facts, err := provider.GetRelevantFacts(map[string]interface{}{
|
||||
"level": 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 1)
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_GetRelevantFacts_WithSearch(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
l0 := memory.NewFact("L0 architecture", memory.LevelProject, "arch", "")
|
||||
searchResult := memory.NewFact("Found by search", memory.LevelDomain, "auth", "")
|
||||
_ = store.Add(context.Background(), l0)
|
||||
store.searchFacts = []*memory.Fact{searchResult}
|
||||
|
||||
provider := NewStoreFactProvider(store, nil)
|
||||
|
||||
facts, err := provider.GetRelevantFacts(map[string]interface{}{
|
||||
"content": "authentication module",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 2) // L0 + search result
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_GetRelevantFacts_Deduplication(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
l0 := memory.NewFact("L0 architecture", memory.LevelProject, "arch", "")
|
||||
_ = store.Add(context.Background(), l0)
|
||||
// Search returns the same fact that's also L0
|
||||
store.searchFacts = []*memory.Fact{l0}
|
||||
|
||||
provider := NewStoreFactProvider(store, nil)
|
||||
|
||||
facts, err := provider.GetRelevantFacts(map[string]interface{}{
|
||||
"content": "architecture",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 1, "duplicate should be removed")
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_GetRelevantFacts_SearchError_GracefulDegradation(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
l0 := memory.NewFact("L0 fact", memory.LevelProject, "arch", "")
|
||||
_ = store.Add(context.Background(), l0)
|
||||
store.searchErr = fmt.Errorf("search broken")
|
||||
|
||||
provider := NewStoreFactProvider(store, nil)
|
||||
|
||||
facts, err := provider.GetRelevantFacts(map[string]interface{}{
|
||||
"content": "test query",
|
||||
})
|
||||
require.NoError(t, err, "should degrade gracefully, not error")
|
||||
assert.Len(t, facts, 1, "should still return L0 facts")
|
||||
}
|
||||
|
||||
func TestStoreFactProvider_RecordAccess(t *testing.T) {
|
||||
store := newMockFactStore()
|
||||
provider := NewStoreFactProvider(store, nil)
|
||||
|
||||
provider.RecordAccess("fact-1")
|
||||
provider.RecordAccess("fact-1")
|
||||
provider.RecordAccess("fact-2")
|
||||
|
||||
provider.mu.Lock()
|
||||
assert.Equal(t, 2, provider.accessCounts["fact-1"])
|
||||
assert.Equal(t, 1, provider.accessCounts["fact-2"])
|
||||
provider.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestExtractQueryFromArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args map[string]interface{}
|
||||
want string
|
||||
}{
|
||||
{"nil", nil, ""},
|
||||
{"empty", map[string]interface{}{}, ""},
|
||||
{"no strings", map[string]interface{}{"level": 0, "flag": true}, ""},
|
||||
{"single string", map[string]interface{}{"content": "hello"}, "hello"},
|
||||
{"empty string", map[string]interface{}{"content": ""}, ""},
|
||||
{"mixed", map[string]interface{}{"content": "hello", "level": 0, "domain": "arch"}, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractQueryFromArgs(tt.args)
|
||||
if tt.name == "mixed" {
|
||||
// Map iteration order is non-deterministic, just check non-empty
|
||||
assert.NotEmpty(t, got)
|
||||
} else {
|
||||
if tt.want == "" {
|
||||
assert.Empty(t, got)
|
||||
} else {
|
||||
assert.Contains(t, got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue