gomcp/internal/infrastructure/sqlite/fact_repo_test.go

297 lines
8.1 KiB
Go

// Copyright 2026 Syntrex Lab. All rights reserved.
// Use of this source code is governed by an Apache-2.0 license
// that can be found in the LICENSE file.
package sqlite
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/syntrex-lab/gomcp/internal/domain/memory"
)
// newTestFactRepo returns a FactRepo backed by the database handle that
// OpenMemory provides. The handle is closed through t.Cleanup once the calling
// test completes, and any setup error aborts the calling test immediately.
func newTestFactRepo(t *testing.T) *FactRepo {
	t.Helper()

	conn, openErr := OpenMemory()
	require.NoError(t, openErr)
	t.Cleanup(func() {
		conn.Close()
	})

	factRepo, repoErr := NewFactRepo(conn)
	require.NoError(t, repoErr)

	return factRepo
}
// TestFactRepo_Add_Get stores a fact with confidence, source and code
// reference set, then checks that Get returns a record whose fields match the
// stored values and whose stale/archived flags are both false.
func TestFactRepo_Add_Get(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	want := memory.NewFact("Go is fast", memory.LevelProject, "core", "engine")
	want.CodeRef = "main.go:42"
	want.Source = "manual"
	want.Confidence = 0.95
	require.NoError(t, repo.Add(ctx, want))

	loaded, getErr := repo.Get(ctx, want.ID)
	require.NoError(t, getErr)
	require.NotNil(t, loaded)

	// Identity and classification fields must round-trip exactly.
	assert.Equal(t, want.ID, loaded.ID)
	assert.Equal(t, want.Content, loaded.Content)
	assert.Equal(t, want.Level, loaded.Level)
	assert.Equal(t, want.Domain, loaded.Domain)
	assert.Equal(t, want.Module, loaded.Module)
	assert.Equal(t, want.CodeRef, loaded.CodeRef)

	// Confidence is a float, so compare within a small tolerance.
	assert.InDelta(t, want.Confidence, loaded.Confidence, 0.001)
	assert.Equal(t, want.Source, loaded.Source)

	assert.False(t, loaded.IsStale)
	assert.False(t, loaded.IsArchived)
}
// TestFactRepo_Get_NotFound checks that Get on an ID that was never stored
// yields an error together with a nil fact.
func TestFactRepo_Get_NotFound(t *testing.T) {
	repo := newTestFactRepo(t)

	missing, lookupErr := repo.Get(context.Background(), "nonexistent")

	assert.Error(t, lookupErr)
	assert.Nil(t, missing)
}
// TestFactRepo_Update stores a fact, changes its content and stale flag,
// persists the change with Update, and checks that Get reflects both edits.
func TestFactRepo_Update(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	stored := memory.NewFact("original", memory.LevelProject, "core", "")
	addErr := repo.Add(ctx, stored)
	require.NoError(t, addErr)

	stored.IsStale = true
	stored.Content = "updated"
	updateErr := repo.Update(ctx, stored)
	require.NoError(t, updateErr)

	reloaded, getErr := repo.Get(ctx, stored.ID)
	require.NoError(t, getErr)
	assert.Equal(t, "updated", reloaded.Content)
	assert.True(t, reloaded.IsStale)
}
// TestFactRepo_Delete adds a fact, deletes it by ID, and checks that a
// subsequent Get reports an error and returns a nil fact.
func TestFactRepo_Delete(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	doomed := memory.NewFact("to delete", memory.LevelProject, "", "")
	require.NoError(t, repo.Add(ctx, doomed))
	require.NoError(t, repo.Delete(ctx, doomed.ID))

	after, getErr := repo.Get(ctx, doomed.ID)
	assert.Error(t, getErr)
	assert.Nil(t, after)
}
// TestFactRepo_ListByDomain seeds three "backend" facts (one of them stale)
// and one "frontend" fact, then checks the number of "backend" facts returned
// with the stale flag argument set to false (2) and to true (3).
func TestFactRepo_ListByDomain(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	seed := []*memory.Fact{
		memory.NewFact("fact1", memory.LevelProject, "backend", ""),
		memory.NewFact("fact2", memory.LevelDomain, "backend", ""),
		memory.NewFact("fact3", memory.LevelProject, "frontend", ""),
		memory.NewFact("stale", memory.LevelProject, "backend", ""),
	}
	seed[3].IsStale = true
	for i := range seed {
		require.NoError(t, repo.Add(ctx, seed[i]))
	}

	// Stale facts excluded.
	active, err := repo.ListByDomain(ctx, "backend", false)
	require.NoError(t, err)
	assert.Len(t, active, 2)

	// Stale facts included.
	withStale, err := repo.ListByDomain(ctx, "backend", true)
	require.NoError(t, err)
	assert.Len(t, withStale, 3)
}
// TestFactRepo_ListByLevel seeds two project-level facts and one domain-level
// fact and checks that listing by memory.LevelProject returns two entries.
func TestFactRepo_ListByLevel(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	seed := []*memory.Fact{
		memory.NewFact("f1", memory.LevelProject, "", ""),
		memory.NewFact("f2", memory.LevelProject, "", ""),
		memory.NewFact("f3", memory.LevelDomain, "", ""),
	}
	for i := range seed {
		require.NoError(t, repo.Add(ctx, seed[i]))
	}

	projectFacts, listErr := repo.ListByLevel(ctx, memory.LevelProject)
	require.NoError(t, listErr)
	assert.Len(t, projectFacts, 2)
}
// TestFactRepo_ListDomains seeds facts across two domains, with "backend"
// used twice, and checks that ListDomains reports each domain exactly once.
func TestFactRepo_ListDomains(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	seed := []*memory.Fact{
		memory.NewFact("f1", memory.LevelProject, "backend", ""),
		memory.NewFact("f2", memory.LevelProject, "frontend", ""),
		memory.NewFact("f3", memory.LevelProject, "backend", ""),
	}
	for i := range seed {
		require.NoError(t, repo.Add(ctx, seed[i]))
	}

	names, listErr := repo.ListDomains(ctx)
	require.NoError(t, listErr)

	// Two entries total means the duplicate "backend" was collapsed.
	assert.Len(t, names, 2)
	assert.Contains(t, names, "backend")
	assert.Contains(t, names, "frontend")
}
// TestFactRepo_GetStale seeds one fresh fact, one stale fact, and one fact
// that is both stale and archived, then checks how many facts GetStale returns
// with its boolean argument set to false (1) and to true (2).
func TestFactRepo_GetStale(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	fresh := memory.NewFact("fresh", memory.LevelProject, "", "")

	staleOnly := memory.NewFact("stale", memory.LevelProject, "", "")
	staleOnly.IsStale = true

	archived := memory.NewFact("archived", memory.LevelProject, "", "")
	archived.IsStale, archived.IsArchived = true, true

	require.NoError(t, repo.Add(ctx, fresh))
	require.NoError(t, repo.Add(ctx, staleOnly))
	require.NoError(t, repo.Add(ctx, archived))

	// Archived facts excluded.
	unarchived, err := repo.GetStale(ctx, false)
	require.NoError(t, err)
	assert.Len(t, unarchived, 1)

	// Archived facts included.
	everything, err := repo.GetStale(ctx, true)
	require.NoError(t, err)
	assert.Len(t, everything, 2)
}
// TestFactRepo_Search seeds three facts, two of which mention "Go", and
// checks that Search for "Go" returns exactly two results. The third argument
// (10) is presumably a result limit; confirm against FactRepo.Search.
func TestFactRepo_Search(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	seed := []*memory.Fact{
		memory.NewFact("Go concurrency patterns", memory.LevelProject, "", ""),
		memory.NewFact("Python is slow", memory.LevelProject, "", ""),
		memory.NewFact("Go channels are great", memory.LevelDomain, "", ""),
	}
	for i := range seed {
		require.NoError(t, repo.Add(ctx, seed[i]))
	}

	hits, searchErr := repo.Search(ctx, "Go", 10)
	require.NoError(t, searchErr)
	assert.Len(t, hits, 2)
}
// TestFactRepo_GetExpired checks that GetExpired returns only the fact whose
// TTL has elapsed: a fact with a 1-second TTL backdated by two hours is
// expected in the result, while a fact without any TTL is not.
func TestFactRepo_GetExpired(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()

	f1 := memory.NewFact("no ttl", memory.LevelProject, "", "")

	f2 := memory.NewFact("expired", memory.LevelProject, "", "")
	f2.TTL = &memory.TTLConfig{TTLSeconds: 1, OnExpire: memory.OnExpireMarkStale}
	// Backdate creation so the 1-second TTL is already long past.
	f2.CreatedAt = time.Now().Add(-2 * time.Hour)
	f2.ValidFrom = f2.CreatedAt

	for _, f := range []*memory.Fact{f1, f2} {
		require.NoError(t, repo.Add(ctx, f))
	}

	expired, err := repo.GetExpired(ctx)
	require.NoError(t, err)
	// Use require rather than assert: a failed length check must stop the
	// test here, otherwise indexing expired[0] below panics on an empty
	// result and hides the real failure.
	require.Len(t, expired, 1)
	assert.Equal(t, f2.ID, expired[0].ID)
}
// TestFactRepo_RefreshTTL stores a fact with a one-hour TTL whose creation
// time is backdated by two hours, calls RefreshTTL, and checks that the
// CreatedAt read back afterwards is later than the backdated value.
func TestFactRepo_RefreshTTL(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	backdated := time.Now().Add(-2 * time.Hour)

	fact := memory.NewFact("refreshable", memory.LevelProject, "", "")
	fact.TTL = &memory.TTLConfig{TTLSeconds: 3600, OnExpire: memory.OnExpireMarkStale}
	fact.CreatedAt = backdated
	fact.ValidFrom = backdated

	addErr := repo.Add(ctx, fact)
	require.NoError(t, addErr)
	refreshErr := repo.RefreshTTL(ctx, fact.ID)
	require.NoError(t, refreshErr)

	refreshed, getErr := repo.Get(ctx, fact.ID)
	require.NoError(t, getErr)
	assert.True(t, refreshed.CreatedAt.After(fact.CreatedAt))
}
// TestFactRepo_Stats seeds three facts spanning two levels and two domains,
// one of them stale and one carrying an embedding, and checks every counter
// reported by Stats against that seed data.
func TestFactRepo_Stats(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	seed := []*memory.Fact{
		memory.NewFact("f1", memory.LevelProject, "backend", ""),
		memory.NewFact("f2", memory.LevelDomain, "backend", ""),
		memory.NewFact("f3", memory.LevelProject, "frontend", ""),
	}
	seed[1].IsStale = true
	seed[2].Embedding = []float64{0.1, 0.2}
	for i := range seed {
		require.NoError(t, repo.Add(ctx, seed[i]))
	}

	summary, statsErr := repo.Stats(ctx)
	require.NoError(t, statsErr)

	assert.Equal(t, 3, summary.TotalFacts)

	// Per-level breakdown.
	assert.Equal(t, 2, summary.ByLevel[memory.LevelProject])
	assert.Equal(t, 1, summary.ByLevel[memory.LevelDomain])

	// Per-domain breakdown.
	assert.Equal(t, 2, summary.ByDomain["backend"])
	assert.Equal(t, 1, summary.ByDomain["frontend"])

	assert.Equal(t, 1, summary.StaleCount)
	assert.Equal(t, 1, summary.WithEmbeddings)
}
// TestFactRepo_EmbeddingRoundTrip stores a fact with a four-element embedding
// and checks that Get returns an embedding of the same length whose first and
// last values match the stored ones within a small tolerance.
func TestFactRepo_EmbeddingRoundTrip(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	fact := memory.NewFact("with embedding", memory.LevelProject, "", "")
	fact.Embedding = []float64{0.1, 0.2, 0.3, -0.5}
	addErr := repo.Add(ctx, fact)
	require.NoError(t, addErr)

	loaded, getErr := repo.Get(ctx, fact.ID)
	require.NoError(t, getErr)

	// The length check must be fatal so the index accesses below are safe.
	require.Len(t, loaded.Embedding, 4)
	assert.InDelta(t, 0.1, loaded.Embedding[0], 0.0001)
	assert.InDelta(t, -0.5, loaded.Embedding[3], 0.0001)
}
// TestFactRepo_TTLConfigRoundTrip stores a fact carrying a TTL configuration
// and checks that Get returns a non-nil TTL whose seconds, refresh trigger and
// on-expire action equal the stored values.
func TestFactRepo_TTLConfigRoundTrip(t *testing.T) {
	ctx := context.Background()
	repo := newTestFactRepo(t)

	fact := memory.NewFact("with ttl", memory.LevelProject, "", "")
	ttl := &memory.TTLConfig{
		TTLSeconds:     3600,
		RefreshTrigger: "main.go",
		OnExpire:       memory.OnExpireArchive,
	}
	fact.TTL = ttl
	addErr := repo.Add(ctx, fact)
	require.NoError(t, addErr)

	loaded, getErr := repo.Get(ctx, fact.ID)
	require.NoError(t, getErr)

	// Fatal nil check: the field accesses below would panic otherwise.
	require.NotNil(t, loaded.TTL)
	assert.Equal(t, 3600, loaded.TTL.TTLSeconds)
	assert.Equal(t, "main.go", loaded.TTL.RefreshTrigger)
	assert.Equal(t, memory.OnExpireArchive, loaded.TTL.OnExpire)
}