gomcp/internal/infrastructure/sqlite/causal_repo_test.go

141 lines
4.4 KiB
Go

// Copyright 2026 Syntrex Lab. All rights reserved.
// Use of this source code is governed by an Apache-2.0 license
// that can be found in the LICENSE file.
package sqlite
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/syntrex-lab/gomcp/internal/domain/causal"
)
func newTestCausalRepo(t *testing.T) *CausalRepo {
t.Helper()
db, err := OpenMemory()
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
repo, err := NewCausalRepo(db)
require.NoError(t, err)
return repo
}
func TestCausalRepo_AddNode_GetChain(t *testing.T) {
repo := newTestCausalRepo(t)
ctx := context.Background()
decision := causal.NewNode(causal.NodeDecision, "Use SQLite")
reason := causal.NewNode(causal.NodeReason, "Embedded, no server needed")
consequence := causal.NewNode(causal.NodeConsequence, "Single binary deployment")
require.NoError(t, repo.AddNode(ctx, decision))
require.NoError(t, repo.AddNode(ctx, reason))
require.NoError(t, repo.AddNode(ctx, consequence))
e1 := causal.NewEdge(reason.ID, decision.ID, causal.EdgeJustifies)
e2 := causal.NewEdge(decision.ID, consequence.ID, causal.EdgeCauses)
require.NoError(t, repo.AddEdge(ctx, e1))
require.NoError(t, repo.AddEdge(ctx, e2))
chain, err := repo.GetChain(ctx, "SQLite", 3)
require.NoError(t, err)
require.NotNil(t, chain)
assert.NotNil(t, chain.Decision)
assert.Equal(t, "Use SQLite", chain.Decision.Content)
assert.Len(t, chain.Reasons, 1)
assert.Len(t, chain.Consequences, 1)
}
func TestCausalRepo_AddNode_Duplicate(t *testing.T) {
repo := newTestCausalRepo(t)
ctx := context.Background()
node := causal.NewNode(causal.NodeDecision, "test")
require.NoError(t, repo.AddNode(ctx, node))
err := repo.AddNode(ctx, node)
assert.Error(t, err) // duplicate primary key
}
func TestCausalRepo_AddEdge_SelfLoop(t *testing.T) {
repo := newTestCausalRepo(t)
ctx := context.Background()
node := causal.NewNode(causal.NodeDecision, "test")
require.NoError(t, repo.AddNode(ctx, node))
edge := causal.NewEdge(node.ID, node.ID, causal.EdgeCauses)
err := repo.AddEdge(ctx, edge)
assert.Error(t, err) // self-loop validation
}
func TestCausalRepo_GetChain_NoResults(t *testing.T) {
repo := newTestCausalRepo(t)
ctx := context.Background()
chain, err := repo.GetChain(ctx, "nonexistent", 3)
require.NoError(t, err)
assert.Equal(t, 0, chain.TotalNodes)
}
func TestCausalRepo_Stats(t *testing.T) {
repo := newTestCausalRepo(t)
ctx := context.Background()
n1 := causal.NewNode(causal.NodeDecision, "D1")
n2 := causal.NewNode(causal.NodeReason, "R1")
n3 := causal.NewNode(causal.NodeConsequence, "C1")
require.NoError(t, repo.AddNode(ctx, n1))
require.NoError(t, repo.AddNode(ctx, n2))
require.NoError(t, repo.AddNode(ctx, n3))
e1 := causal.NewEdge(n2.ID, n1.ID, causal.EdgeJustifies)
require.NoError(t, repo.AddEdge(ctx, e1))
stats, err := repo.Stats(ctx)
require.NoError(t, err)
assert.Equal(t, 3, stats.TotalNodes)
assert.Equal(t, 1, stats.TotalEdges)
assert.Equal(t, 1, stats.ByType[causal.NodeDecision])
assert.Equal(t, 1, stats.ByType[causal.NodeReason])
assert.Equal(t, 1, stats.ByType[causal.NodeConsequence])
}
func TestCausalRepo_ComplexChain(t *testing.T) {
repo := newTestCausalRepo(t)
ctx := context.Background()
decision := causal.NewNode(causal.NodeDecision, "Use Go for MCP server")
r1 := causal.NewNode(causal.NodeReason, "Performance")
r2 := causal.NewNode(causal.NodeReason, "Single binary")
c1 := causal.NewNode(causal.NodeConsequence, "Faster startup")
cn1 := causal.NewNode(causal.NodeConstraint, "Must support CGO-free")
a1 := causal.NewNode(causal.NodeAlternative, "Use Rust")
for _, n := range []*causal.Node{decision, r1, r2, c1, cn1, a1} {
require.NoError(t, repo.AddNode(ctx, n))
}
edges := []*causal.Edge{
causal.NewEdge(r1.ID, decision.ID, causal.EdgeJustifies),
causal.NewEdge(r2.ID, decision.ID, causal.EdgeJustifies),
causal.NewEdge(decision.ID, c1.ID, causal.EdgeCauses),
causal.NewEdge(cn1.ID, decision.ID, causal.EdgeConstrains),
}
for _, e := range edges {
require.NoError(t, repo.AddEdge(ctx, e))
}
chain, err := repo.GetChain(ctx, "Go for MCP", 5)
require.NoError(t, err)
require.NotNil(t, chain.Decision)
assert.Equal(t, "Use Go for MCP server", chain.Decision.Content)
assert.Len(t, chain.Reasons, 2)
assert.Len(t, chain.Consequences, 1)
assert.Len(t, chain.Constraints, 1)
assert.GreaterOrEqual(t, chain.TotalNodes, 5)
}