initial: Syntrex extraction from sentinel-community (615 files)

This commit is contained in:
DmitrL-dev 2026-03-11 15:12:02 +10:00
commit 2c50c993b1
175 changed files with 32396 additions and 0 deletions

View file

@ -0,0 +1,220 @@
package sqlite
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/sentinel-community/gomcp/internal/domain/causal"
)
// CausalRepo implements causal.CausalStore using SQLite.
// Compatible with causal_chains.db schema.
type CausalRepo struct {
	db *DB // shared SQLite handle; writes are serialized by DB.Exec's mutex
}
// NewCausalRepo creates a CausalRepo and ensures the schema exists.
// The migration runs on every construction; it is idempotent.
func NewCausalRepo(db *DB) (*CausalRepo, error) {
	r := &CausalRepo{db: db}
	if err := r.migrate(); err != nil {
		return nil, fmt.Errorf("causal repo migrate: %w", err)
	}
	return r, nil
}
// migrate creates the causal_nodes and causal_edges tables and their
// supporting indexes. Every statement is guarded with IF NOT EXISTS, so
// the migration is idempotent and safe to run on each startup.
func (r *CausalRepo) migrate() error {
	stmts := []string{
		// One row per node (decision, reason, consequence, ...).
		`CREATE TABLE IF NOT EXISTS causal_nodes (
id TEXT PRIMARY KEY,
node_type TEXT NOT NULL,
content TEXT NOT NULL,
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
session_id TEXT,
metadata TEXT DEFAULT '{}'
)`,
		// Typed, weighted links between nodes; the composite primary key
		// allows multiple edge types between the same pair of nodes.
		`CREATE TABLE IF NOT EXISTS causal_edges (
from_id TEXT NOT NULL,
to_id TEXT NOT NULL,
edge_type TEXT NOT NULL,
strength REAL DEFAULT 1.0,
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (from_id, to_id, edge_type),
FOREIGN KEY (from_id) REFERENCES causal_nodes(id),
FOREIGN KEY (to_id) REFERENCES causal_nodes(id)
)`,
		`CREATE INDEX IF NOT EXISTS idx_nodes_type ON causal_nodes(node_type)`,
		`CREATE INDEX IF NOT EXISTS idx_nodes_session ON causal_nodes(session_id)`,
		`CREATE INDEX IF NOT EXISTS idx_edges_from ON causal_edges(from_id)`,
		`CREATE INDEX IF NOT EXISTS idx_edges_to ON causal_edges(to_id)`,
	}
	for _, s := range stmts {
		if _, err := r.db.Exec(s); err != nil {
			return fmt.Errorf("exec migration: %w", err)
		}
	}
	return nil
}
// AddNode validates and persists a single causal node.
//
// NOTE(review): ctx is accepted for interface compatibility but not
// propagated — the DB wrapper exposes no context-aware Exec.
func (r *CausalRepo) AddNode(ctx context.Context, node *causal.Node) error {
	if err := node.Validate(); err != nil {
		return fmt.Errorf("validate node: %w", err)
	}
	if _, err := r.db.Exec(`INSERT INTO causal_nodes (id, node_type, content, created_at)
VALUES (?, ?, ?, ?)`,
		node.ID, string(node.Type), node.Content, node.CreatedAt.Format(timeFormat),
	); err != nil {
		return fmt.Errorf("insert node: %w", err)
	}
	return nil
}
// AddEdge validates and persists a single causal edge.
// The edge's strength column keeps its schema default (1.0).
func (r *CausalRepo) AddEdge(ctx context.Context, edge *causal.Edge) error {
	if err := edge.Validate(); err != nil {
		return fmt.Errorf("validate edge: %w", err)
	}
	if _, err := r.db.Exec(`INSERT INTO causal_edges (from_id, to_id, edge_type)
VALUES (?, ?, ?)`,
		edge.FromID, edge.ToID, string(edge.Type),
	); err != nil {
		return fmt.Errorf("insert edge: %w", err)
	}
	return nil
}
// GetChain builds a causal chain around the first decision node whose
// content matches query (substring LIKE match). The chain is assembled
// from the decision's direct neighbours: incoming edges populate
// Reasons/Constraints (falling back to classification by node type),
// outgoing edges populate Consequences/Alternatives.
//
// NOTE(review): maxDepth is accepted but unused — traversal is a single
// hop in each direction. Kept for interface compatibility.
func (r *CausalRepo) GetChain(ctx context.Context, query string, maxDepth int) (*causal.Chain, error) {
	chain := &causal.Chain{}
	// Find decision node matching query.
	row := r.db.QueryRow(`SELECT id, node_type, content, created_at
FROM causal_nodes WHERE node_type = 'decision' AND content LIKE ? LIMIT 1`,
		"%"+query+"%")
	var id, nodeType, content, createdAt string
	err := row.Scan(&id, &nodeType, &content, &createdAt)
	if errors.Is(err, sql.ErrNoRows) {
		// No decision found — an empty chain is not an error.
		return chain, nil
	}
	if err != nil {
		// BUGFIX: previously ANY scan failure (I/O error, corrupt row, ...)
		// was silently treated as "no results"; real errors must propagate.
		return nil, fmt.Errorf("scan decision: %w", err)
	}
	t, _ := time.Parse(timeFormat, createdAt) // zero time on malformed timestamp
	chain.Decision = &causal.Node{ID: id, Type: causal.NodeType(nodeType), Content: content, CreatedAt: t}
	chain.TotalNodes = 1
	// Incoming edges (nodes that point TO the decision).
	inRows, err := r.db.Query(`SELECT n.id, n.node_type, n.content, n.created_at, e.edge_type
FROM causal_edges e JOIN causal_nodes n ON e.from_id = n.id
WHERE e.to_id = ?`, id)
	if err != nil {
		return nil, fmt.Errorf("query incoming edges: %w", err)
	}
	defer inRows.Close()
	for inRows.Next() {
		var nid, nt, nc, nca, et string
		if err := inRows.Scan(&nid, &nt, &nc, &nca, &et); err != nil {
			return nil, fmt.Errorf("scan incoming: %w", err)
		}
		tt, _ := time.Parse(timeFormat, nca)
		node := &causal.Node{ID: nid, Type: causal.NodeType(nt), Content: nc, CreatedAt: tt}
		chain.TotalNodes++
		switch causal.EdgeType(et) {
		case causal.EdgeJustifies:
			chain.Reasons = append(chain.Reasons, node)
		case causal.EdgeConstrains:
			chain.Constraints = append(chain.Constraints, node)
		default:
			// Classify by node type if edge type doesn't match.
			switch causal.NodeType(nt) {
			case causal.NodeAlternative:
				chain.Alternatives = append(chain.Alternatives, node)
			case causal.NodeReason:
				chain.Reasons = append(chain.Reasons, node)
			case causal.NodeConstraint:
				chain.Constraints = append(chain.Constraints, node)
			}
		}
	}
	if err := inRows.Err(); err != nil {
		return nil, err
	}
	// Outgoing edges (nodes that the decision points TO).
	outRows, err := r.db.Query(`SELECT n.id, n.node_type, n.content, n.created_at, e.edge_type
FROM causal_edges e JOIN causal_nodes n ON e.to_id = n.id
WHERE e.from_id = ?`, id)
	if err != nil {
		return nil, fmt.Errorf("query outgoing edges: %w", err)
	}
	defer outRows.Close()
	for outRows.Next() {
		var nid, nt, nc, nca, et string
		if err := outRows.Scan(&nid, &nt, &nc, &nca, &et); err != nil {
			return nil, fmt.Errorf("scan outgoing: %w", err)
		}
		tt, _ := time.Parse(timeFormat, nca)
		node := &causal.Node{ID: nid, Type: causal.NodeType(nt), Content: nc, CreatedAt: tt}
		chain.TotalNodes++
		switch causal.EdgeType(et) {
		case causal.EdgeCauses:
			chain.Consequences = append(chain.Consequences, node)
		default:
			switch causal.NodeType(nt) {
			case causal.NodeConsequence:
				chain.Consequences = append(chain.Consequences, node)
			case causal.NodeAlternative:
				chain.Alternatives = append(chain.Alternatives, node)
			}
		}
	}
	if err := outRows.Err(); err != nil {
		return nil, err
	}
	return chain, nil
}
// Stats returns aggregate statistics about the causal store: total node
// and edge counts plus a per-node-type breakdown.
func (r *CausalRepo) Stats(ctx context.Context) (*causal.CausalStats, error) {
	stats := &causal.CausalStats{ByType: make(map[causal.NodeType]int)}
	if err := r.db.QueryRow(`SELECT COUNT(*) FROM causal_nodes`).Scan(&stats.TotalNodes); err != nil {
		return nil, err
	}
	if err := r.db.QueryRow(`SELECT COUNT(*) FROM causal_edges`).Scan(&stats.TotalEdges); err != nil {
		return nil, err
	}
	rows, err := r.db.Query(`SELECT node_type, COUNT(*) FROM causal_nodes GROUP BY node_type`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var (
			typ string
			n   int
		)
		if err := rows.Scan(&typ, &n); err != nil {
			return nil, err
		}
		stats.ByType[causal.NodeType(typ)] = n
	}
	return stats, rows.Err()
}

// Ensure CausalRepo implements causal.CausalStore.
var _ causal.CausalStore = (*CausalRepo)(nil)

View file

@ -0,0 +1,137 @@
package sqlite
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/causal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestCausalRepo returns a CausalRepo backed by an in-memory database
// that is closed automatically when the test finishes.
func newTestCausalRepo(t *testing.T) *CausalRepo {
	t.Helper()
	mem, err := OpenMemory()
	require.NoError(t, err)
	t.Cleanup(func() { mem.Close() })
	r, err := NewCausalRepo(mem)
	require.NoError(t, err)
	return r
}
// TestCausalRepo_AddNode_GetChain verifies that a reason→decision→consequence
// triple round-trips into a populated chain.
func TestCausalRepo_AddNode_GetChain(t *testing.T) {
	repo := newTestCausalRepo(t)
	ctx := context.Background()

	decision := causal.NewNode(causal.NodeDecision, "Use SQLite")
	reason := causal.NewNode(causal.NodeReason, "Embedded, no server needed")
	consequence := causal.NewNode(causal.NodeConsequence, "Single binary deployment")
	for _, n := range []*causal.Node{decision, reason, consequence} {
		require.NoError(t, repo.AddNode(ctx, n))
	}
	for _, e := range []*causal.Edge{
		causal.NewEdge(reason.ID, decision.ID, causal.EdgeJustifies),
		causal.NewEdge(decision.ID, consequence.ID, causal.EdgeCauses),
	} {
		require.NoError(t, repo.AddEdge(ctx, e))
	}

	chain, err := repo.GetChain(ctx, "SQLite", 3)
	require.NoError(t, err)
	require.NotNil(t, chain)
	assert.NotNil(t, chain.Decision)
	assert.Equal(t, "Use SQLite", chain.Decision.Content)
	assert.Len(t, chain.Reasons, 1)
	assert.Len(t, chain.Consequences, 1)
}
// TestCausalRepo_AddNode_Duplicate checks that re-inserting the same node
// ID is rejected by the primary key.
func TestCausalRepo_AddNode_Duplicate(t *testing.T) {
	repo := newTestCausalRepo(t)
	ctx := context.Background()
	n := causal.NewNode(causal.NodeDecision, "test")
	require.NoError(t, repo.AddNode(ctx, n))
	// Second insert with the same ID must fail.
	assert.Error(t, repo.AddNode(ctx, n))
}
// TestCausalRepo_AddEdge_SelfLoop checks that an edge from a node to
// itself is rejected by edge validation.
func TestCausalRepo_AddEdge_SelfLoop(t *testing.T) {
	repo := newTestCausalRepo(t)
	ctx := context.Background()
	n := causal.NewNode(causal.NodeDecision, "test")
	require.NoError(t, repo.AddNode(ctx, n))
	assert.Error(t, repo.AddEdge(ctx, causal.NewEdge(n.ID, n.ID, causal.EdgeCauses)))
}
// TestCausalRepo_GetChain_NoResults checks that a query with no matching
// decision yields an empty chain rather than an error.
func TestCausalRepo_GetChain_NoResults(t *testing.T) {
	repo := newTestCausalRepo(t)
	chain, err := repo.GetChain(context.Background(), "nonexistent", 3)
	require.NoError(t, err)
	assert.Equal(t, 0, chain.TotalNodes)
}
// TestCausalRepo_Stats verifies node/edge totals and per-type counts.
func TestCausalRepo_Stats(t *testing.T) {
	repo := newTestCausalRepo(t)
	ctx := context.Background()
	d := causal.NewNode(causal.NodeDecision, "D1")
	re := causal.NewNode(causal.NodeReason, "R1")
	co := causal.NewNode(causal.NodeConsequence, "C1")
	for _, n := range []*causal.Node{d, re, co} {
		require.NoError(t, repo.AddNode(ctx, n))
	}
	require.NoError(t, repo.AddEdge(ctx, causal.NewEdge(re.ID, d.ID, causal.EdgeJustifies)))

	stats, err := repo.Stats(ctx)
	require.NoError(t, err)
	assert.Equal(t, 3, stats.TotalNodes)
	assert.Equal(t, 1, stats.TotalEdges)
	for _, typ := range []causal.NodeType{causal.NodeDecision, causal.NodeReason, causal.NodeConsequence} {
		assert.Equal(t, 1, stats.ByType[typ])
	}
}
// TestCausalRepo_ComplexChain exercises a decision with multiple reasons,
// a consequence, a constraint, and an unlinked alternative.
func TestCausalRepo_ComplexChain(t *testing.T) {
	repo := newTestCausalRepo(t)
	ctx := context.Background()
	decision := causal.NewNode(causal.NodeDecision, "Use Go for MCP server")
	perf := causal.NewNode(causal.NodeReason, "Performance")
	binary := causal.NewNode(causal.NodeReason, "Single binary")
	startup := causal.NewNode(causal.NodeConsequence, "Faster startup")
	cgoFree := causal.NewNode(causal.NodeConstraint, "Must support CGO-free")
	rust := causal.NewNode(causal.NodeAlternative, "Use Rust")
	for _, n := range []*causal.Node{decision, perf, binary, startup, cgoFree, rust} {
		require.NoError(t, repo.AddNode(ctx, n))
	}
	for _, e := range []*causal.Edge{
		causal.NewEdge(perf.ID, decision.ID, causal.EdgeJustifies),
		causal.NewEdge(binary.ID, decision.ID, causal.EdgeJustifies),
		causal.NewEdge(decision.ID, startup.ID, causal.EdgeCauses),
		causal.NewEdge(cgoFree.ID, decision.ID, causal.EdgeConstrains),
	} {
		require.NoError(t, repo.AddEdge(ctx, e))
	}

	chain, err := repo.GetChain(ctx, "Go for MCP", 5)
	require.NoError(t, err)
	require.NotNil(t, chain.Decision)
	assert.Equal(t, "Use Go for MCP server", chain.Decision.Content)
	assert.Len(t, chain.Reasons, 2)
	assert.Len(t, chain.Consequences, 1)
	assert.Len(t, chain.Constraints, 1)
	assert.GreaterOrEqual(t, chain.TotalNodes, 5)
}

View file

@ -0,0 +1,254 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"path/filepath"
"strings"
"github.com/sentinel-community/gomcp/internal/domain/crystal"
)
// CrystalRepo implements crystal.CrystalStore using SQLite.
// Compatible with crystals.db schema.
type CrystalRepo struct {
	db *DB // shared SQLite handle
}
// NewCrystalRepo creates a CrystalRepo and ensures the schema exists.
// The migration is idempotent, so construction is safe on every startup.
func NewCrystalRepo(db *DB) (*CrystalRepo, error) {
	r := &CrystalRepo{db: db}
	if err := r.migrate(); err != nil {
		return nil, fmt.Errorf("crystal repo migrate: %w", err)
	}
	return r, nil
}
// migrate creates the crystals and metadata tables plus supporting
// indexes. All statements use IF NOT EXISTS and are idempotent.
func (r *CrystalRepo) migrate() error {
	stmts := []string{
		// One row per indexed source file; content holds a JSON blob
		// (see serializedCrystal).
		`CREATE TABLE IF NOT EXISTS crystals (
path TEXT PRIMARY KEY,
name TEXT,
content BLOB,
primitives_count INTEGER,
token_count INTEGER,
indexed_at REAL,
source_mtime REAL,
source_hash TEXT,
last_validated REAL,
human_confirmed INTEGER DEFAULT 0
)`,
		// Free-form key/value store for repo-level metadata.
		`CREATE TABLE IF NOT EXISTS metadata (
key TEXT PRIMARY KEY,
value TEXT
)`,
		`CREATE INDEX IF NOT EXISTS idx_mtime ON crystals(source_mtime)`,
		`CREATE INDEX IF NOT EXISTS idx_indexed ON crystals(indexed_at)`,
	}
	for _, s := range stmts {
		if _, err := r.db.Exec(s); err != nil {
			return fmt.Errorf("exec migration: %w", err)
		}
	}
	return nil
}
// serializedCrystal is the JSON structure stored in the content BLOB.
// Primitives and ContentHash have no dedicated column and live only
// here; Path/Name/TokenCount are duplicated from the columns.
type serializedCrystal struct {
	Path        string              `json:"path"`
	Name        string              `json:"name"`
	TokenCount  int                 `json:"token_count"`
	ContentHash string              `json:"content_hash"`
	Primitives  []crystal.Primitive `json:"primitives"`
}
// Upsert inserts or replaces a crystal keyed by its path. The primitive
// list and content hash are serialized to JSON in the content column.
func (r *CrystalRepo) Upsert(ctx context.Context, c *crystal.Crystal) error {
	blob, err := json.Marshal(serializedCrystal{
		Path:        c.Path,
		Name:        c.Name,
		TokenCount:  c.TokenCount,
		ContentHash: c.ContentHash,
		Primitives:  c.Primitives,
	})
	if err != nil {
		return fmt.Errorf("marshal crystal: %w", err)
	}
	if _, err := r.db.Exec(`INSERT OR REPLACE INTO crystals
(path, name, content, primitives_count, token_count,
indexed_at, source_mtime, source_hash, last_validated, human_confirmed)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		c.Path, c.Name, blob, c.PrimitivesCount, c.TokenCount,
		c.IndexedAt, c.SourceMtime, c.SourceHash,
		c.LastValidated, boolToInt(c.HumanConfirmed),
	); err != nil {
		return fmt.Errorf("upsert crystal: %w", err)
	}
	return nil
}
// Get retrieves a crystal by path, decoding the JSON content blob when
// present. A missing row is reported as a "not found" error.
func (r *CrystalRepo) Get(ctx context.Context, path string) (*crystal.Crystal, error) {
	row := r.db.QueryRow(`SELECT path, name, content, primitives_count, token_count,
indexed_at, source_mtime, source_hash, last_validated, human_confirmed
FROM crystals WHERE path = ?`, path)
	var (
		c              crystal.Crystal
		blob           []byte
		lastValidated  sql.NullFloat64
		humanConfirmed int
	)
	if err := row.Scan(&c.Path, &c.Name, &blob, &c.PrimitivesCount, &c.TokenCount,
		&c.IndexedAt, &c.SourceMtime, &c.SourceHash, &lastValidated, &humanConfirmed); err != nil {
		if err == sql.ErrNoRows {
			return nil, fmt.Errorf("crystal %s not found", path)
		}
		return nil, fmt.Errorf("scan crystal: %w", err)
	}
	if lastValidated.Valid {
		c.LastValidated = lastValidated.Float64
	}
	c.HumanConfirmed = humanConfirmed != 0
	if len(blob) > 0 {
		var sc serializedCrystal
		if err := json.Unmarshal(blob, &sc); err != nil {
			return nil, fmt.Errorf("unmarshal crystal content: %w", err)
		}
		c.Primitives = sc.Primitives
		c.ContentHash = sc.ContentHash
	}
	return &c, nil
}
// Delete removes a crystal by path; deleting a missing path is an error.
func (r *CrystalRepo) Delete(ctx context.Context, path string) error {
	res, err := r.db.Exec(`DELETE FROM crystals WHERE path = ?`, path)
	if err != nil {
		return fmt.Errorf("delete crystal: %w", err)
	}
	if n, _ := res.RowsAffected(); n == 0 {
		return fmt.Errorf("crystal %s not found", path)
	}
	return nil
}
// List returns crystals whose path matches a SQL LIKE pattern, newest
// first. An empty pattern returns everything; limit <= 0 defaults to 100.
func (r *CrystalRepo) List(ctx context.Context, pattern string, limit int) ([]*crystal.Crystal, error) {
	if limit <= 0 {
		limit = 100
	}
	const base = `SELECT path, name, content, primitives_count, token_count,
indexed_at, source_mtime, source_hash, last_validated, human_confirmed
FROM crystals`
	var (
		rows *sql.Rows
		err  error
	)
	if pattern == "" {
		rows, err = r.db.Query(base+` ORDER BY indexed_at DESC LIMIT ?`, limit)
	} else {
		rows, err = r.db.Query(base+` WHERE path LIKE ? ORDER BY indexed_at DESC LIMIT ?`, pattern, limit)
	}
	if err != nil {
		return nil, fmt.Errorf("list crystals: %w", err)
	}
	defer rows.Close()
	return scanCrystals(rows)
}
// Search returns crystals whose serialized content or name contains
// query as a substring. limit <= 0 defaults to 50.
func (r *CrystalRepo) Search(ctx context.Context, query string, limit int) ([]*crystal.Crystal, error) {
	if limit <= 0 {
		limit = 50
	}
	needle := "%" + query + "%"
	// CAST to TEXT required because SQLite LIKE doesn't match on raw BLOB.
	rows, err := r.db.Query(`SELECT path, name, content, primitives_count, token_count,
indexed_at, source_mtime, source_hash, last_validated, human_confirmed
FROM crystals WHERE CAST(content AS TEXT) LIKE ? OR name LIKE ? LIMIT ?`,
		needle, needle, limit)
	if err != nil {
		return nil, fmt.Errorf("search crystals: %w", err)
	}
	defer rows.Close()
	return scanCrystals(rows)
}
// Stats returns aggregate statistics: totals plus a count of crystals
// per lowercase file extension ("(no ext)" for extensionless paths).
func (r *CrystalRepo) Stats(ctx context.Context) (*crystal.CrystalStats, error) {
	stats := &crystal.CrystalStats{ByExtension: make(map[string]int)}
	err := r.db.QueryRow(`SELECT COUNT(*), COALESCE(SUM(primitives_count),0), COALESCE(SUM(token_count),0) FROM crystals`).
		Scan(&stats.TotalCrystals, &stats.TotalPrimitives, &stats.TotalTokens)
	if err != nil {
		return nil, err
	}
	// Count by file extension.
	rows, err := r.db.Query(`SELECT path FROM crystals`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var p string
		if err := rows.Scan(&p); err != nil {
			return nil, err
		}
		ext := strings.ToLower(filepath.Ext(p))
		if ext == "" {
			ext = "(no ext)"
		}
		stats.ByExtension[ext]++
	}
	return stats, rows.Err()
}
// scanCrystals converts result rows into Crystal values, decoding the
// JSON content blob (primitives, content hash) when it is non-empty.
func scanCrystals(rows *sql.Rows) ([]*crystal.Crystal, error) {
	var out []*crystal.Crystal
	for rows.Next() {
		var (
			c              crystal.Crystal
			blob           []byte
			lastValidated  sql.NullFloat64
			humanConfirmed int
		)
		if err := rows.Scan(&c.Path, &c.Name, &blob, &c.PrimitivesCount, &c.TokenCount,
			&c.IndexedAt, &c.SourceMtime, &c.SourceHash, &lastValidated, &humanConfirmed); err != nil {
			return nil, fmt.Errorf("scan crystal row: %w", err)
		}
		if lastValidated.Valid {
			c.LastValidated = lastValidated.Float64
		}
		c.HumanConfirmed = humanConfirmed != 0
		if len(blob) > 0 {
			var sc serializedCrystal
			if err := json.Unmarshal(blob, &sc); err != nil {
				return nil, fmt.Errorf("unmarshal crystal content: %w", err)
			}
			c.Primitives = sc.Primitives
			c.ContentHash = sc.ContentHash
		}
		out = append(out, &c)
	}
	return out, rows.Err()
}

// Ensure CrystalRepo implements crystal.CrystalStore.
var _ crystal.CrystalStore = (*CrystalRepo)(nil)

View file

@ -0,0 +1,160 @@
package sqlite
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/crystal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestCrystalRepo returns a CrystalRepo backed by an in-memory
// database that is closed when the test finishes.
func newTestCrystalRepo(t *testing.T) *CrystalRepo {
	t.Helper()
	mem, err := OpenMemory()
	require.NoError(t, err)
	t.Cleanup(func() { mem.Close() })
	r, err := NewCrystalRepo(mem)
	require.NoError(t, err)
	return r
}
// makeCrystal builds a fixture crystal with two function primitives and
// fixed counters/timestamps.
func makeCrystal(path, name string) *crystal.Crystal {
	prims := []crystal.Primitive{
		{PType: "function", Name: "main", Value: "func main()", SourceLine: 1, Confidence: 1.0},
		{PType: "function", Name: "init", Value: "func init()", SourceLine: 5, Confidence: 0.9},
	}
	return &crystal.Crystal{
		Path:            path,
		Name:            name,
		TokenCount:      150,
		ContentHash:     "hash123",
		PrimitivesCount: 2,
		Primitives:      prims,
		IndexedAt:       1700000000.0,
		SourceMtime:     1699999000.0,
		SourceHash:      "src_hash",
	}
}
// TestCrystalRepo_Upsert_Get verifies a crystal round-trips through the
// store, including the JSON-serialized primitives.
func TestCrystalRepo_Upsert_Get(t *testing.T) {
	repo := newTestCrystalRepo(t)
	ctx := context.Background()
	require.NoError(t, repo.Upsert(ctx, makeCrystal("cmd/main.go", "main.go")))

	got, err := repo.Get(ctx, "cmd/main.go")
	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, "cmd/main.go", got.Path)
	assert.Equal(t, "main.go", got.Name)
	assert.Equal(t, 150, got.TokenCount)
	assert.Equal(t, 2, got.PrimitivesCount)
	assert.Len(t, got.Primitives, 2)
	assert.Equal(t, "function", got.Primitives[0].PType)
}
// TestCrystalRepo_Upsert_Overwrite checks that re-upserting the same
// path replaces the row rather than duplicating it.
func TestCrystalRepo_Upsert_Overwrite(t *testing.T) {
	repo := newTestCrystalRepo(t)
	ctx := context.Background()
	fixture := makeCrystal("main.go", "main.go")
	require.NoError(t, repo.Upsert(ctx, fixture))

	fixture.TokenCount = 300
	fixture.PrimitivesCount = 5
	require.NoError(t, repo.Upsert(ctx, fixture))

	got, err := repo.Get(ctx, "main.go")
	require.NoError(t, err)
	assert.Equal(t, 300, got.TokenCount)
	assert.Equal(t, 5, got.PrimitivesCount)
}
// TestCrystalRepo_Get_NotFound checks that a missing path yields an
// error and a nil crystal.
func TestCrystalRepo_Get_NotFound(t *testing.T) {
	repo := newTestCrystalRepo(t)
	got, err := repo.Get(context.Background(), "nonexistent.go")
	assert.Error(t, err)
	assert.Nil(t, got)
}
// TestCrystalRepo_Delete checks that a deleted crystal is no longer
// retrievable.
func TestCrystalRepo_Delete(t *testing.T) {
	repo := newTestCrystalRepo(t)
	ctx := context.Background()
	require.NoError(t, repo.Upsert(ctx, makeCrystal("delete_me.go", "delete_me.go")))
	require.NoError(t, repo.Delete(ctx, "delete_me.go"))

	got, err := repo.Get(ctx, "delete_me.go")
	assert.Error(t, err)
	assert.Nil(t, got)
}
// TestCrystalRepo_List verifies both the list-all path and LIKE-pattern
// filtering.
func TestCrystalRepo_List(t *testing.T) {
	repo := newTestCrystalRepo(t)
	ctx := context.Background()
	paths := []string{"cmd/main.go", "internal/foo.go", "internal/bar.go", "README.md"}
	for _, p := range paths {
		require.NoError(t, repo.Upsert(ctx, makeCrystal(p, p)))
	}

	all, err := repo.List(ctx, "", 100)
	require.NoError(t, err)
	assert.Len(t, all, 4)

	internal, err := repo.List(ctx, "internal%", 100)
	require.NoError(t, err)
	assert.Len(t, internal, 2)
}
// TestCrystalRepo_Search checks substring search across the serialized
// primitive content.
func TestCrystalRepo_Search(t *testing.T) {
	repo := newTestCrystalRepo(t)
	ctx := context.Background()
	prim := func(name, value string, line int) []crystal.Primitive {
		return []crystal.Primitive{{PType: "function", Name: name, Value: value, SourceLine: line, Confidence: 1.0}}
	}
	server := makeCrystal("server.go", "server.go")
	server.Primitives = prim("handleRequest", "func handleRequest()", 10)
	client := makeCrystal("client.go", "client.go")
	client.Primitives = prim("sendRequest", "func sendRequest()", 5)
	utils := makeCrystal("utils.go", "utils.go")
	utils.Primitives = prim("helper", "func helper()", 1)
	for _, c := range []*crystal.Crystal{server, client, utils} {
		require.NoError(t, repo.Upsert(ctx, c))
	}

	results, err := repo.Search(ctx, "Request", 10)
	require.NoError(t, err)
	assert.Len(t, results, 2)
}
// TestCrystalRepo_Stats verifies aggregate counts over two crystals.
func TestCrystalRepo_Stats(t *testing.T) {
	repo := newTestCrystalRepo(t)
	ctx := context.Background()
	goFile := makeCrystal("main.go", "main.go")
	goFile.TokenCount = 100
	pyFile := makeCrystal("server.py", "server.py")
	pyFile.TokenCount = 200
	pyFile.PrimitivesCount = 5
	for _, c := range []*crystal.Crystal{goFile, pyFile} {
		require.NoError(t, repo.Upsert(ctx, c))
	}

	stats, err := repo.Stats(ctx)
	require.NoError(t, err)
	assert.Equal(t, 2, stats.TotalCrystals)
	assert.Equal(t, 300, stats.TotalTokens)
}

View file

@ -0,0 +1,105 @@
// Package sqlite provides SQLite-based persistence using modernc.org/sqlite (pure Go, no CGO).
package sqlite
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"sync"
_ "modernc.org/sqlite"
)
// DB wraps a *sql.DB with SQLite-specific configuration.
type DB struct {
	db   *sql.DB      // underlying pool (limited to a single connection in Open)
	path string       // database file path, or ":memory:"
	mu   sync.RWMutex // serializes Exec (write lock) against Query/QueryRow (read lock)
}
// Open opens or creates an SQLite database at the given path, creating
// parent directories as needed. It applies WAL mode and recommended
// pragmas for performance, then limits the pool to one connection
// (SQLite allows a single writer).
func Open(path string) (*DB, error) {
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return nil, fmt.Errorf("create db directory: %w", err)
	}
	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, fmt.Errorf("open sqlite: %w", err)
	}
	// Apply performance pragmas.
	for _, p := range []string{
		"PRAGMA journal_mode=WAL",
		"PRAGMA synchronous=NORMAL",
		"PRAGMA cache_size=-64000", // 64MB
		"PRAGMA foreign_keys=ON",
		"PRAGMA busy_timeout=5000",
	} {
		if _, err := db.Exec(p); err != nil {
			db.Close()
			return nil, fmt.Errorf("exec pragma %q: %w", p, err)
		}
	}
	// Connection pool settings for SQLite (single writer).
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	return &DB{db: db, path: path}, nil
}
// OpenMemory opens an in-memory SQLite database (for testing).
//
// BUGFIX: the pool must be limited to a single connection. With the
// ":memory:" DSN, every new pooled connection opens its OWN empty
// database, so with the default pool size tables created on one
// connection are invisible to statements that happen to run on another.
// The limits are set before the pragma so the pragma's connection is the
// one kept alive.
func OpenMemory() (*DB, error) {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		return nil, fmt.Errorf("open in-memory sqlite: %w", err)
	}
	// One shared connection so all statements see the same in-memory DB.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil {
		db.Close()
		return nil, fmt.Errorf("enable foreign keys: %w", err)
	}
	return &DB{db: db, path: ":memory:"}, nil
}
// SqlDB returns the underlying *sql.DB.
// Callers that use it directly bypass the wrapper's mutex.
func (d *DB) SqlDB() *sql.DB {
	return d.db
}

// Path returns the database file path (":memory:" for in-memory databases).
func (d *DB) Path() string {
	return d.path
}

// Close closes the database connection.
func (d *DB) Close() error {
	return d.db.Close()
}
// Exec executes a query that doesn't return rows.
// Takes the exclusive lock so writes never interleave with reads issued
// through this wrapper.
// NOTE(review): database/sql is itself safe for concurrent use; the
// mutex adds wrapper-level write serialization on top of that.
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.db.Exec(query, args...)
}

// Query executes a query that returns rows.
// NOTE(review): the read lock is released when this returns — before the
// caller iterates the rows — so it guards only statement submission, not
// result consumption.
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.db.Query(query, args...)
}

// QueryRow executes a query that returns at most one row.
// The returned row is scanned by the caller after the lock is released.
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.db.QueryRow(query, args...)
}

View file

@ -0,0 +1,698 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
)
// timeFormat is the canonical timestamp layout used for TEXT time columns.
const timeFormat = time.RFC3339Nano

// FactRepo implements memory.FactStore using SQLite.
// Compatible with memory_bridge_v2.db schema v2.0.0.
type FactRepo struct {
	db *DB // shared SQLite handle
}
// NewFactRepo creates a FactRepo and ensures the schema exists,
// including the incremental v3.3 migration.
func NewFactRepo(db *DB) (*FactRepo, error) {
	r := &FactRepo{db: db}
	if err := r.migrate(); err != nil {
		return nil, fmt.Errorf("fact repo migrate: %w", err)
	}
	return r, nil
}
// migrate creates the hierarchical fact schema (facts, hierarchy,
// embeddings, centroids, schema_info) and then applies the incremental
// v3.3 additions (hit counters and the synapses table). The base
// statements are idempotent; the v3.3 ALTERs rely on ignored errors.
func (r *FactRepo) migrate() error {
	stmts := []string{
		`CREATE TABLE IF NOT EXISTS hierarchical_facts (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
level INTEGER NOT NULL DEFAULT 0,
domain TEXT,
module TEXT,
code_ref TEXT,
parent_id TEXT,
embedding BLOB,
ttl_config TEXT,
created_at TEXT NOT NULL,
valid_from TEXT NOT NULL,
valid_until TEXT,
is_stale INTEGER DEFAULT 0,
is_archived INTEGER DEFAULT 0,
confidence REAL DEFAULT 1.0,
source TEXT DEFAULT 'manual',
session_id TEXT,
FOREIGN KEY (parent_id) REFERENCES hierarchical_facts(id)
)`,
		// Explicit parent/child links in addition to the parent_id column.
		`CREATE TABLE IF NOT EXISTS fact_hierarchy (
parent_id TEXT NOT NULL,
child_id TEXT NOT NULL,
relationship TEXT DEFAULT 'contains',
PRIMARY KEY (parent_id, child_id),
FOREIGN KEY (parent_id) REFERENCES hierarchical_facts(id),
FOREIGN KEY (child_id) REFERENCES hierarchical_facts(id)
)`,
		// Secondary copy of embeddings keyed by fact for similarity search.
		`CREATE TABLE IF NOT EXISTS embeddings_index (
fact_id TEXT PRIMARY KEY,
embedding BLOB NOT NULL,
model_name TEXT DEFAULT 'all-MiniLM-L6-v2',
updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (fact_id) REFERENCES hierarchical_facts(id)
)`,
		`CREATE TABLE IF NOT EXISTS domain_centroids (
domain TEXT PRIMARY KEY,
centroid BLOB NOT NULL,
fact_count INTEGER DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
)`,
		`CREATE TABLE IF NOT EXISTS schema_info (
key TEXT PRIMARY KEY,
value TEXT
)`,
		`CREATE INDEX IF NOT EXISTS idx_facts_level ON hierarchical_facts(level)`,
		`CREATE INDEX IF NOT EXISTS idx_facts_domain ON hierarchical_facts(domain)`,
		`CREATE INDEX IF NOT EXISTS idx_facts_module ON hierarchical_facts(module)`,
		`CREATE INDEX IF NOT EXISTS idx_facts_stale ON hierarchical_facts(is_stale)`,
		`CREATE INDEX IF NOT EXISTS idx_facts_session ON hierarchical_facts(session_id)`,
		`CREATE INDEX IF NOT EXISTS idx_facts_source ON hierarchical_facts(source)`,
		`INSERT OR REPLACE INTO schema_info (key, value) VALUES ('version', '3.3.0')`,
	}
	// v3.3 migration: add hit_count, last_accessed_at, synapses table.
	v33Stmts := []string{
		// Safe ALTER TABLE — ignore error if column already exists.
		`ALTER TABLE hierarchical_facts ADD COLUMN hit_count INTEGER DEFAULT 0`,
		`ALTER TABLE hierarchical_facts ADD COLUMN last_accessed_at TEXT`,
		`CREATE TABLE IF NOT EXISTS synapses (
id INTEGER PRIMARY KEY AUTOINCREMENT,
fact_id_a TEXT NOT NULL,
fact_id_b TEXT NOT NULL,
confidence REAL DEFAULT 0.0,
status TEXT DEFAULT 'PENDING',
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (fact_id_a) REFERENCES hierarchical_facts(id),
FOREIGN KEY (fact_id_b) REFERENCES hierarchical_facts(id)
)`,
		`CREATE INDEX IF NOT EXISTS idx_facts_hit_count ON hierarchical_facts(hit_count)`,
		`CREATE INDEX IF NOT EXISTS idx_synapses_status ON synapses(status)`,
	}
	for _, s := range stmts {
		if _, err := r.db.Exec(s); err != nil {
			// NOTE(review): s[:40] panics if a statement is ever shorter
			// than 40 bytes; all current statements are longer, but a
			// length guard would be safer.
			return fmt.Errorf("exec %q: %w", s[:40], err)
		}
	}
	// v3.3: ignore errors on ALTER TABLE (column may already exist).
	// NOTE(review): this also swallows real failures (e.g. a malformed
	// CREATE TABLE) — distinguishing "duplicate column" would be stricter.
	for _, s := range v33Stmts {
		_, _ = r.db.Exec(s)
	}
	return nil
}
// Add inserts a new fact and, when an embedding is present, mirrors it
// into embeddings_index.
func (r *FactRepo) Add(ctx context.Context, fact *memory.Fact) error {
	embeddingBlob, err := encodeEmbedding(fact.Embedding)
	if err != nil {
		return err
	}
	ttlJSON, err := encodeTTL(fact.TTL)
	if err != nil {
		return err
	}
	// valid_until is nullable; NULL means "no expiry".
	var validUntil *string
	if fact.ValidUntil != nil {
		s := fact.ValidUntil.Format(timeFormat)
		validUntil = &s
	}
	if _, err := r.db.Exec(`INSERT INTO hierarchical_facts
(id, content, level, domain, module, code_ref, parent_id,
embedding, ttl_config, created_at, valid_from, valid_until,
is_stale, is_archived, confidence, source, session_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		fact.ID, fact.Content, int(fact.Level),
		nullStr(fact.Domain), nullStr(fact.Module), nullStr(fact.CodeRef),
		nullStr(fact.ParentID),
		embeddingBlob, ttlJSON,
		fact.CreatedAt.Format(timeFormat), fact.ValidFrom.Format(timeFormat), validUntil,
		boolToInt(fact.IsStale), boolToInt(fact.IsArchived),
		fact.Confidence, fact.Source, nullStr(fact.SessionID),
	); err != nil {
		return fmt.Errorf("insert fact: %w", err)
	}
	if len(fact.Embedding) == 0 {
		return nil
	}
	// Mirror the embedding into the dedicated index table.
	if _, err := r.db.Exec(`INSERT OR REPLACE INTO embeddings_index (fact_id, embedding) VALUES (?, ?)`,
		fact.ID, embeddingBlob); err != nil {
		return fmt.Errorf("insert embedding index: %w", err)
	}
	return nil
}
// Get retrieves a single fact by its ID.
func (r *FactRepo) Get(ctx context.Context, id string) (*memory.Fact, error) {
	row := r.db.QueryRow(`SELECT id, content, level, domain, module, code_ref, parent_id,
embedding, ttl_config, created_at, valid_from, valid_until,
is_stale, is_archived, confidence, source, session_id
FROM hierarchical_facts WHERE id = ?`, id)
	return scanFact(row)
}
// Update rewrites all columns of an existing fact and keeps
// embeddings_index in sync with the fact row (mirroring Add).
// Returns an error when the fact does not exist.
func (r *FactRepo) Update(ctx context.Context, fact *memory.Fact) error {
	embeddingBlob, err := encodeEmbedding(fact.Embedding)
	if err != nil {
		return err
	}
	ttlJSON, err := encodeTTL(fact.TTL)
	if err != nil {
		return err
	}
	var validUntil *string
	if fact.ValidUntil != nil {
		s := fact.ValidUntil.Format(timeFormat)
		validUntil = &s
	}
	result, err := r.db.Exec(`UPDATE hierarchical_facts SET
content=?, level=?, domain=?, module=?, code_ref=?, parent_id=?,
embedding=?, ttl_config=?, created_at=?, valid_from=?, valid_until=?,
is_stale=?, is_archived=?, confidence=?, source=?, session_id=?
WHERE id=?`,
		fact.Content, int(fact.Level),
		nullStr(fact.Domain), nullStr(fact.Module), nullStr(fact.CodeRef),
		nullStr(fact.ParentID),
		embeddingBlob, ttlJSON,
		fact.CreatedAt.Format(timeFormat), fact.ValidFrom.Format(timeFormat), validUntil,
		boolToInt(fact.IsStale), boolToInt(fact.IsArchived),
		fact.Confidence, fact.Source, nullStr(fact.SessionID),
		fact.ID,
	)
	if err != nil {
		return fmt.Errorf("update fact: %w", err)
	}
	n, _ := result.RowsAffected()
	if n == 0 {
		return fmt.Errorf("fact %s not found", fact.ID)
	}
	// BUGFIX: Add maintains embeddings_index but Update previously did
	// not, so a changed embedding left a stale index entry behind.
	if len(fact.Embedding) > 0 {
		if _, err := r.db.Exec(`INSERT OR REPLACE INTO embeddings_index (fact_id, embedding) VALUES (?, ?)`,
			fact.ID, embeddingBlob); err != nil {
			return fmt.Errorf("update embedding index: %w", err)
		}
	} else {
		// Best-effort: drop any leftover index entry for a fact whose
		// embedding was removed.
		_, _ = r.db.Exec(`DELETE FROM embeddings_index WHERE fact_id = ?`, fact.ID)
	}
	return nil
}
// Delete removes a fact by ID along with its dependent rows
// (embeddings_index, fact_hierarchy, synapses), so foreign-key
// constraints cannot block the delete.
func (r *FactRepo) Delete(ctx context.Context, id string) error {
	// Remove from embeddings_index first (FK).
	_, _ = r.db.Exec(`DELETE FROM embeddings_index WHERE fact_id = ?`, id)
	// Remove from fact_hierarchy.
	_, _ = r.db.Exec(`DELETE FROM fact_hierarchy WHERE parent_id = ? OR child_id = ?`, id, id)
	// BUGFIX: also remove synapses referencing the fact — with
	// foreign_keys=ON their FK would otherwise reject the delete. Error
	// ignored because the v3.3 synapses table may be absent on older DBs.
	_, _ = r.db.Exec(`DELETE FROM synapses WHERE fact_id_a = ? OR fact_id_b = ?`, id, id)
	// Remove the fact itself.
	result, err := r.db.Exec(`DELETE FROM hierarchical_facts WHERE id = ?`, id)
	if err != nil {
		return fmt.Errorf("delete fact: %w", err)
	}
	n, _ := result.RowsAffected()
	if n == 0 {
		return fmt.Errorf("fact %s not found", id)
	}
	return nil
}
// ListByDomain returns facts in a domain; stale facts are filtered out
// unless includeStale is set.
func (r *FactRepo) ListByDomain(ctx context.Context, domain string, includeStale bool) ([]*memory.Fact, error) {
	q := `SELECT id, content, level, domain, module, code_ref, parent_id,
embedding, ttl_config, created_at, valid_from, valid_until,
is_stale, is_archived, confidence, source, session_id
FROM hierarchical_facts WHERE domain = ?`
	if !includeStale {
		q += ` AND is_stale = 0`
	}
	rows, err := r.db.Query(q, domain)
	if err != nil {
		return nil, fmt.Errorf("list by domain: %w", err)
	}
	defer rows.Close()
	return scanFacts(rows)
}
// ListByLevel returns all facts at a given hierarchy level.
// No staleness filter is applied; stale and archived facts are included.
func (r *FactRepo) ListByLevel(ctx context.Context, level memory.HierLevel) ([]*memory.Fact, error) {
	rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
		embedding, ttl_config, created_at, valid_from, valid_until,
		is_stale, is_archived, confidence, source, session_id
		FROM hierarchical_facts WHERE level = ?`, int(level))
	if err != nil {
		return nil, fmt.Errorf("list by level: %w", err)
	}
	defer rows.Close()
	return scanFacts(rows)
}
// ListDomains returns the distinct, non-empty domain names present in the store.
func (r *FactRepo) ListDomains(ctx context.Context) ([]string, error) {
	rows, err := r.db.Query(`SELECT DISTINCT domain FROM hierarchical_facts WHERE domain IS NOT NULL AND domain != ''`)
	if err != nil {
		return nil, fmt.Errorf("list domains: %w", err)
	}
	defer rows.Close()
	var out []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		out = append(out, name)
	}
	return out, rows.Err()
}
// GetStale returns stale facts. When includeArchived is false, facts
// that have already been archived are excluded.
func (r *FactRepo) GetStale(ctx context.Context, includeArchived bool) ([]*memory.Fact, error) {
	// Default to the narrower query; widen it when archived facts are wanted.
	query := `SELECT id, content, level, domain, module, code_ref, parent_id,
		embedding, ttl_config, created_at, valid_from, valid_until,
		is_stale, is_archived, confidence, source, session_id
		FROM hierarchical_facts WHERE is_stale = 1 AND is_archived = 0`
	if includeArchived {
		query = `SELECT id, content, level, domain, module, code_ref, parent_id,
		embedding, ttl_config, created_at, valid_from, valid_until,
		is_stale, is_archived, confidence, source, session_id
		FROM hierarchical_facts WHERE is_stale = 1`
	}
	rows, err := r.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("get stale: %w", err)
	}
	defer rows.Close()
	return scanFacts(rows)
}
// Search performs a LIKE-based substring search on fact content.
// A non-positive limit falls back to 50 results.
// NOTE(review): LIKE metacharacters (% and _) in query are not escaped —
// presumably intentional to allow wildcard searches; confirm with callers.
func (r *FactRepo) Search(ctx context.Context, query string, limit int) ([]*memory.Fact, error) {
	if limit <= 0 {
		limit = 50
	}
	pattern := "%" + query + "%"
	rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
		embedding, ttl_config, created_at, valid_from, valid_until,
		is_stale, is_archived, confidence, source, session_id
		FROM hierarchical_facts WHERE content LIKE ? LIMIT ?`,
		pattern, limit)
	if err != nil {
		return nil, fmt.Errorf("search: %w", err)
	}
	defer rows.Close()
	return scanFacts(rows)
}
// GetExpired returns facts whose TTL has expired.
// ttl_config is stored as opaque JSON, so candidates with any TTL are
// loaded and expiry is evaluated in Go against CreatedAt.
func (r *FactRepo) GetExpired(ctx context.Context) ([]*memory.Fact, error) {
	rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
		embedding, ttl_config, created_at, valid_from, valid_until,
		is_stale, is_archived, confidence, source, session_id
		FROM hierarchical_facts WHERE ttl_config IS NOT NULL AND ttl_config != ''`)
	if err != nil {
		return nil, fmt.Errorf("get expired: %w", err)
	}
	defer rows.Close()
	candidates, err := scanFacts(rows)
	if err != nil {
		return nil, err
	}
	var expired []*memory.Fact
	for _, f := range candidates {
		if f.TTL == nil || !f.TTL.IsExpired(f.CreatedAt) {
			continue
		}
		expired = append(expired, f)
	}
	return expired, nil
}
// RefreshTTL resets a fact's created_at to now, effectively restarting
// its TTL countdown (TTL expiry is computed relative to created_at).
func (r *FactRepo) RefreshTTL(ctx context.Context, id string) error {
	res, err := r.db.Exec(`UPDATE hierarchical_facts SET created_at = ? WHERE id = ?`,
		time.Now().Format(timeFormat), id)
	if err != nil {
		return fmt.Errorf("refresh ttl: %w", err)
	}
	if affected, _ := res.RowsAffected(); affected == 0 {
		return fmt.Errorf("fact %s not found", id)
	}
	return nil
}
// Stats returns aggregate statistics about the fact store.
//
// ColdCount is best-effort: on schemas that predate the hit_count column
// the cold-count query fails and ColdCount is reported as 0.
func (r *FactRepo) Stats(ctx context.Context) (*memory.FactStoreStats, error) {
	stats := &memory.FactStoreStats{
		ByLevel:  make(map[memory.HierLevel]int),
		ByDomain: make(map[string]int),
	}
	// Scalar counters.
	if err := r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts`).Scan(&stats.TotalFacts); err != nil {
		return nil, err
	}
	if err := r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts WHERE is_stale = 1`).Scan(&stats.StaleCount); err != nil {
		return nil, err
	}
	if err := r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts WHERE embedding IS NOT NULL`).Scan(&stats.WithEmbeddings); err != nil {
		return nil, err
	}
	// Grouped breakdowns; each helper checks rows.Err immediately after
	// iterating so an iteration failure is reported before further queries run.
	if err := r.statsByLevel(stats); err != nil {
		return nil, err
	}
	if err := r.statsByDomain(stats); err != nil {
		return nil, err
	}
	// Gene count (Genome Layer).
	if err := r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts WHERE source = 'genome'`).Scan(&stats.GeneCount); err != nil {
		return nil, err
	}
	// v3.3: Cold count (hit_count=0, created >30 days ago, not gene, not archived).
	thirtyDaysAgo := time.Now().AddDate(0, 0, -30).Format(timeFormat)
	row := r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts
		WHERE hit_count = 0 AND created_at < ? AND source != 'genome' AND is_archived = 0`,
		thirtyDaysAgo)
	if err := row.Scan(&stats.ColdCount); err != nil {
		// Ignore: the hit_count column may not exist yet on older databases.
		stats.ColdCount = 0
	}
	return stats, nil
}

// statsByLevel fills stats.ByLevel with per-level fact counts.
func (r *FactRepo) statsByLevel(stats *memory.FactStoreStats) error {
	rows, err := r.db.Query(`SELECT level, COUNT(*) FROM hierarchical_facts GROUP BY level`)
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var level, count int
		if err := rows.Scan(&level, &count); err != nil {
			return err
		}
		stats.ByLevel[memory.HierLevel(level)] = count
	}
	return rows.Err()
}

// statsByDomain fills stats.ByDomain with per-domain fact counts,
// skipping facts that have no domain.
func (r *FactRepo) statsByDomain(stats *memory.FactStoreStats) error {
	rows, err := r.db.Query(`SELECT domain, COUNT(*) FROM hierarchical_facts WHERE domain IS NOT NULL AND domain != '' GROUP BY domain`)
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var domain string
		var count int
		if err := rows.Scan(&domain, &count); err != nil {
			return err
		}
		stats.ByDomain[domain] = count
	}
	return rows.Err()
}
// --- helpers ---
// scanFactColumns builds a Fact from one hierarchical_facts row.
// scan is either (*sql.Row).Scan or (*sql.Rows).Scan; label is used to
// wrap scan errors so the two callers keep their distinct error messages.
// The column order must match the SELECT lists used throughout this file.
func scanFactColumns(scan func(dest ...interface{}) error, label string) (*memory.Fact, error) {
	var f memory.Fact
	var levelInt int
	var domain, module, codeRef, parentID, sessionID sql.NullString
	var embeddingBlob []byte
	var ttlJSON sql.NullString
	var createdAt, validFrom string
	var validUntil sql.NullString
	var isStale, isArchived int
	err := scan(&f.ID, &f.Content, &levelInt,
		&domain, &module, &codeRef, &parentID,
		&embeddingBlob, &ttlJSON,
		&createdAt, &validFrom, &validUntil,
		&isStale, &isArchived, &f.Confidence, &f.Source, &sessionID,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, fmt.Errorf("fact not found")
		}
		return nil, fmt.Errorf("%s: %w", label, err)
	}
	f.Level = memory.HierLevel(levelInt)
	f.Domain = domain.String
	f.Module = module.String
	f.CodeRef = codeRef.String
	f.ParentID = parentID.String
	f.SessionID = sessionID.String
	f.IsStale = isStale != 0
	f.IsArchived = isArchived != 0
	f.IsGene = f.Source == "genome" // Genome Layer: auto-detect from source
	f.CreatedAt, _ = time.Parse(timeFormat, createdAt)
	f.ValidFrom, _ = time.Parse(timeFormat, validFrom)
	f.UpdatedAt = f.CreatedAt // We don't have a separate updated_at column in the DB schema.
	if validUntil.Valid {
		t, _ := time.Parse(timeFormat, validUntil.String)
		f.ValidUntil = &t
	}
	if len(embeddingBlob) > 0 {
		if err := json.Unmarshal(embeddingBlob, &f.Embedding); err != nil {
			return nil, fmt.Errorf("decode embedding: %w", err)
		}
	}
	if ttlJSON.Valid && ttlJSON.String != "" {
		f.TTL = &memory.TTLConfig{}
		if err := json.Unmarshal([]byte(ttlJSON.String), f.TTL); err != nil {
			return nil, fmt.Errorf("decode ttl_config: %w", err)
		}
	}
	return &f, nil
}

// scanFact reads a single fact from a QueryRow result.
// Returns "fact not found" when the row does not exist.
func scanFact(row *sql.Row) (*memory.Fact, error) {
	return scanFactColumns(row.Scan, "scan fact")
}

// scanFacts drains a multi-row result set into a slice of facts.
func scanFacts(rows *sql.Rows) ([]*memory.Fact, error) {
	var facts []*memory.Fact
	for rows.Next() {
		f, err := scanFactColumns(rows.Scan, "scan fact row")
		if err != nil {
			return nil, err
		}
		facts = append(facts, f)
	}
	return facts, rows.Err()
}
// encodeEmbedding serializes an embedding vector as JSON bytes.
// Nil or empty vectors encode as nil (stored as SQL NULL).
func encodeEmbedding(embedding []float64) ([]byte, error) {
	if len(embedding) == 0 {
		return nil, nil
	}
	blob, err := json.Marshal(embedding)
	if err != nil {
		return nil, fmt.Errorf("encode embedding: %w", err)
	}
	return blob, nil
}
// encodeTTL serializes a TTL config to its JSON string form.
// A nil config yields a nil pointer (stored as SQL NULL).
func encodeTTL(ttl *memory.TTLConfig) (*string, error) {
	if ttl == nil {
		return nil, nil
	}
	data, err := json.Marshal(ttl)
	if err != nil {
		return nil, fmt.Errorf("encode ttl: %w", err)
	}
	encoded := string(data)
	return &encoded, nil
}
func nullStr(s string) sql.NullString {
if s == "" {
return sql.NullString{}
}
return sql.NullString{String: s, Valid: true}
}
// boolToInt converts a bool to its SQLite integer representation (1 or 0).
func boolToInt(b bool) int {
	if !b {
		return 0
	}
	return 1
}
// Compile-time assertion that *FactRepo satisfies the memory.FactStore interface.
var _ memory.FactStore = (*FactRepo)(nil)
// ListGenes returns all genome facts (immutable survival invariants),
// i.e. rows with source = 'genome', ordered oldest first.
func (r *FactRepo) ListGenes(ctx context.Context) ([]*memory.Fact, error) {
	rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
		embedding, ttl_config, created_at, valid_from, valid_until,
		is_stale, is_archived, confidence, source, session_id
		FROM hierarchical_facts WHERE source = 'genome' ORDER BY created_at ASC`)
	if err != nil {
		return nil, fmt.Errorf("list genes: %w", err)
	}
	defer rows.Close()
	return scanFacts(rows)
}
// --- v3.3 Context GC ---
// TouchFact increments hit_count and updates last_accessed_at.
func (r *FactRepo) TouchFact(ctx context.Context, id string) error {
now := time.Now().Format(timeFormat)
_, err := r.db.Exec(`UPDATE hierarchical_facts
SET hit_count = COALESCE(hit_count, 0) + 1, last_accessed_at = ?
WHERE id = ?`, now, id)
if err != nil {
return fmt.Errorf("touch fact: %w", err)
}
return nil
}
// GetColdFacts returns up to limit facts that were never accessed
// (hit_count = 0) and were created more than 30 days ago, oldest first.
// Genes (source='genome') and archived facts are excluded; a
// non-positive limit defaults to 50.
func (r *FactRepo) GetColdFacts(ctx context.Context, limit int) ([]*memory.Fact, error) {
	if limit <= 0 {
		limit = 50
	}
	cutoff := time.Now().AddDate(0, 0, -30).Format(timeFormat)
	rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
		embedding, ttl_config, created_at, valid_from, valid_until,
		is_stale, is_archived, confidence, source, session_id
		FROM hierarchical_facts
		WHERE COALESCE(hit_count, 0) = 0
		  AND created_at < ?
		  AND source != 'genome'
		  AND is_archived = 0
		ORDER BY created_at ASC
		LIMIT ?`, cutoff, limit)
	if err != nil {
		return nil, fmt.Errorf("get cold facts: %w", err)
	}
	defer rows.Close()
	return scanFacts(rows)
}
// CompressFacts archives the given facts and records a single summary
// fact in their place. Genes (source='genome') are silently skipped, as
// are IDs that cannot be loaded. The summary inherits the domain and
// hierarchy level of the first archivable fact and is tagged with
// source "consolidation". Returns the new summary fact's ID.
//
// Returns an error when ids is empty, summary is empty, no fact could be
// archived, or any archive/insert step fails.
func (r *FactRepo) CompressFacts(ctx context.Context, ids []string, summary string) (string, error) {
	if len(ids) == 0 {
		return "", fmt.Errorf("no fact IDs provided")
	}
	if summary == "" {
		return "", fmt.Errorf("summary text is required")
	}
	// Single pass: load each fact once (the previous implementation
	// fetched every fact twice), remember domain/level from the first
	// archivable fact, and archive the rest as we go.
	var domain string
	var level memory.HierLevel
	archived := 0
	for _, id := range ids {
		f, err := r.Get(ctx, id)
		if err != nil || f.IsGene {
			continue // skip genes and unknown IDs
		}
		if archived == 0 {
			domain = f.Domain
			level = f.Level
		}
		f.Archive()
		if err := r.Update(ctx, f); err != nil {
			return "", fmt.Errorf("archive fact %s: %w", id, err)
		}
		archived++
	}
	if archived == 0 {
		return "", fmt.Errorf("no facts were archived (all genes or not found)")
	}
	// Create the consolidated summary fact.
	summaryFact := memory.NewFact(summary, level, domain, "")
	summaryFact.Source = "consolidation"
	if err := r.Add(ctx, summaryFact); err != nil {
		return "", fmt.Errorf("create summary fact: %w", err)
	}
	return summaryFact.ID, nil
}

View file

@ -0,0 +1,293 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestFactRepo returns a FactRepo backed by an in-memory SQLite DB
// that is closed automatically when the test finishes.
func newTestFactRepo(t *testing.T) *FactRepo {
	t.Helper()
	db, err := OpenMemory()
	require.NoError(t, err)
	t.Cleanup(func() { db.Close() })
	repo, err := NewFactRepo(db)
	require.NoError(t, err)
	return repo
}
// TestFactRepo_Add_Get verifies a fully-populated fact survives an
// Add/Get round trip with all scalar fields intact.
func TestFactRepo_Add_Get(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	fact := memory.NewFact("Go is fast", memory.LevelProject, "core", "engine")
	fact.Confidence = 0.95
	fact.Source = "manual"
	fact.CodeRef = "main.go:42"
	err := repo.Add(ctx, fact)
	require.NoError(t, err)
	got, err := repo.Get(ctx, fact.ID)
	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, fact.ID, got.ID)
	assert.Equal(t, fact.Content, got.Content)
	assert.Equal(t, fact.Level, got.Level)
	assert.Equal(t, fact.Domain, got.Domain)
	assert.Equal(t, fact.Module, got.Module)
	assert.Equal(t, fact.CodeRef, got.CodeRef)
	// Confidence is a float — compare with tolerance.
	assert.InDelta(t, fact.Confidence, got.Confidence, 0.001)
	assert.Equal(t, fact.Source, got.Source)
	assert.False(t, got.IsStale)
	assert.False(t, got.IsArchived)
}

// TestFactRepo_Get_NotFound verifies Get errors (and returns nil) for an
// unknown ID.
func TestFactRepo_Get_NotFound(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	got, err := repo.Get(ctx, "nonexistent")
	assert.Error(t, err)
	assert.Nil(t, got)
}

// TestFactRepo_Update verifies content and staleness changes persist.
func TestFactRepo_Update(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	fact := memory.NewFact("original", memory.LevelProject, "core", "")
	require.NoError(t, repo.Add(ctx, fact))
	fact.Content = "updated"
	fact.IsStale = true
	require.NoError(t, repo.Update(ctx, fact))
	got, err := repo.Get(ctx, fact.ID)
	require.NoError(t, err)
	assert.Equal(t, "updated", got.Content)
	assert.True(t, got.IsStale)
}

// TestFactRepo_Delete verifies a deleted fact is no longer retrievable.
func TestFactRepo_Delete(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	fact := memory.NewFact("to delete", memory.LevelProject, "", "")
	require.NoError(t, repo.Add(ctx, fact))
	err := repo.Delete(ctx, fact.ID)
	require.NoError(t, err)
	got, err := repo.Get(ctx, fact.ID)
	assert.Error(t, err)
	assert.Nil(t, got)
}
// TestFactRepo_ListByDomain verifies domain filtering and the
// includeStale toggle.
func TestFactRepo_ListByDomain(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f1 := memory.NewFact("fact1", memory.LevelProject, "backend", "")
	f2 := memory.NewFact("fact2", memory.LevelDomain, "backend", "")
	f3 := memory.NewFact("fact3", memory.LevelProject, "frontend", "")
	f4 := memory.NewFact("stale", memory.LevelProject, "backend", "")
	f4.IsStale = true
	for _, f := range []*memory.Fact{f1, f2, f3, f4} {
		require.NoError(t, repo.Add(ctx, f))
	}
	// Without stale
	facts, err := repo.ListByDomain(ctx, "backend", false)
	require.NoError(t, err)
	assert.Len(t, facts, 2)
	// With stale
	facts, err = repo.ListByDomain(ctx, "backend", true)
	require.NoError(t, err)
	assert.Len(t, facts, 3)
}

// TestFactRepo_ListByLevel verifies filtering by hierarchy level.
func TestFactRepo_ListByLevel(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f1 := memory.NewFact("f1", memory.LevelProject, "", "")
	f2 := memory.NewFact("f2", memory.LevelProject, "", "")
	f3 := memory.NewFact("f3", memory.LevelDomain, "", "")
	for _, f := range []*memory.Fact{f1, f2, f3} {
		require.NoError(t, repo.Add(ctx, f))
	}
	facts, err := repo.ListByLevel(ctx, memory.LevelProject)
	require.NoError(t, err)
	assert.Len(t, facts, 2)
}

// TestFactRepo_ListDomains verifies distinct domain names are returned
// without duplicates.
func TestFactRepo_ListDomains(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f1 := memory.NewFact("f1", memory.LevelProject, "backend", "")
	f2 := memory.NewFact("f2", memory.LevelProject, "frontend", "")
	f3 := memory.NewFact("f3", memory.LevelProject, "backend", "")
	for _, f := range []*memory.Fact{f1, f2, f3} {
		require.NoError(t, repo.Add(ctx, f))
	}
	domains, err := repo.ListDomains(ctx)
	require.NoError(t, err)
	assert.Len(t, domains, 2)
	assert.Contains(t, domains, "backend")
	assert.Contains(t, domains, "frontend")
}

// TestFactRepo_GetStale verifies the stale filter and the
// includeArchived toggle.
func TestFactRepo_GetStale(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f1 := memory.NewFact("fresh", memory.LevelProject, "", "")
	f2 := memory.NewFact("stale", memory.LevelProject, "", "")
	f2.IsStale = true
	f3 := memory.NewFact("archived", memory.LevelProject, "", "")
	f3.IsStale = true
	f3.IsArchived = true
	for _, f := range []*memory.Fact{f1, f2, f3} {
		require.NoError(t, repo.Add(ctx, f))
	}
	// Without archived
	stale, err := repo.GetStale(ctx, false)
	require.NoError(t, err)
	assert.Len(t, stale, 1)
	// With archived
	stale, err = repo.GetStale(ctx, true)
	require.NoError(t, err)
	assert.Len(t, stale, 2)
}
// TestFactRepo_Search verifies substring matching on fact content.
func TestFactRepo_Search(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f1 := memory.NewFact("Go concurrency patterns", memory.LevelProject, "", "")
	f2 := memory.NewFact("Python is slow", memory.LevelProject, "", "")
	f3 := memory.NewFact("Go channels are great", memory.LevelDomain, "", "")
	for _, f := range []*memory.Fact{f1, f2, f3} {
		require.NoError(t, repo.Add(ctx, f))
	}
	results, err := repo.Search(ctx, "Go", 10)
	require.NoError(t, err)
	assert.Len(t, results, 2)
}

// TestFactRepo_GetExpired verifies only facts whose TTL has elapsed
// (relative to CreatedAt) are reported.
func TestFactRepo_GetExpired(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f1 := memory.NewFact("no ttl", memory.LevelProject, "", "")
	f2 := memory.NewFact("expired", memory.LevelProject, "", "")
	f2.TTL = &memory.TTLConfig{TTLSeconds: 1, OnExpire: memory.OnExpireMarkStale}
	// Backdate creation so the 1-second TTL is long past.
	f2.CreatedAt = time.Now().Add(-2 * time.Hour)
	f2.ValidFrom = f2.CreatedAt
	for _, f := range []*memory.Fact{f1, f2} {
		require.NoError(t, repo.Add(ctx, f))
	}
	expired, err := repo.GetExpired(ctx)
	require.NoError(t, err)
	assert.Len(t, expired, 1)
	assert.Equal(t, f2.ID, expired[0].ID)
}

// TestFactRepo_RefreshTTL verifies RefreshTTL moves created_at forward.
func TestFactRepo_RefreshTTL(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f := memory.NewFact("refreshable", memory.LevelProject, "", "")
	f.TTL = &memory.TTLConfig{TTLSeconds: 3600, OnExpire: memory.OnExpireMarkStale}
	f.CreatedAt = time.Now().Add(-2 * time.Hour)
	f.ValidFrom = f.CreatedAt
	require.NoError(t, repo.Add(ctx, f))
	require.NoError(t, repo.RefreshTTL(ctx, f.ID))
	got, err := repo.Get(ctx, f.ID)
	require.NoError(t, err)
	assert.True(t, got.CreatedAt.After(f.CreatedAt))
}
// TestFactRepo_Stats verifies the aggregate counters and per-level /
// per-domain breakdowns.
func TestFactRepo_Stats(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f1 := memory.NewFact("f1", memory.LevelProject, "backend", "")
	f2 := memory.NewFact("f2", memory.LevelDomain, "backend", "")
	f2.IsStale = true
	f3 := memory.NewFact("f3", memory.LevelProject, "frontend", "")
	f3.Embedding = []float64{0.1, 0.2}
	for _, f := range []*memory.Fact{f1, f2, f3} {
		require.NoError(t, repo.Add(ctx, f))
	}
	stats, err := repo.Stats(ctx)
	require.NoError(t, err)
	assert.Equal(t, 3, stats.TotalFacts)
	assert.Equal(t, 2, stats.ByLevel[memory.LevelProject])
	assert.Equal(t, 1, stats.ByLevel[memory.LevelDomain])
	assert.Equal(t, 2, stats.ByDomain["backend"])
	assert.Equal(t, 1, stats.ByDomain["frontend"])
	assert.Equal(t, 1, stats.StaleCount)
	assert.Equal(t, 1, stats.WithEmbeddings)
}

// TestFactRepo_EmbeddingRoundTrip verifies embeddings survive
// JSON-blob storage and reload without precision loss.
func TestFactRepo_EmbeddingRoundTrip(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f := memory.NewFact("with embedding", memory.LevelProject, "", "")
	f.Embedding = []float64{0.1, 0.2, 0.3, -0.5}
	require.NoError(t, repo.Add(ctx, f))
	got, err := repo.Get(ctx, f.ID)
	require.NoError(t, err)
	require.Len(t, got.Embedding, 4)
	assert.InDelta(t, 0.1, got.Embedding[0], 0.0001)
	assert.InDelta(t, -0.5, got.Embedding[3], 0.0001)
}

// TestFactRepo_TTLConfigRoundTrip verifies all TTLConfig fields survive
// serialization to the ttl_config column and back.
func TestFactRepo_TTLConfigRoundTrip(t *testing.T) {
	repo := newTestFactRepo(t)
	ctx := context.Background()
	f := memory.NewFact("with ttl", memory.LevelProject, "", "")
	f.TTL = &memory.TTLConfig{
		TTLSeconds:     3600,
		RefreshTrigger: "main.go",
		OnExpire:       memory.OnExpireArchive,
	}
	require.NoError(t, repo.Add(ctx, f))
	got, err := repo.Get(ctx, f.ID)
	require.NoError(t, err)
	require.NotNil(t, got.TTL)
	assert.Equal(t, 3600, got.TTL.TTLSeconds)
	assert.Equal(t, "main.go", got.TTL.RefreshTrigger)
	assert.Equal(t, memory.OnExpireArchive, got.TTL.OnExpire)
}

View file

@ -0,0 +1,128 @@
package sqlite
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"
	"unicode/utf8"
)
// InteractionEntry represents a single tool call record in the interaction log.
type InteractionEntry struct {
	ID       int64     `json:"id"`                  // auto-increment row id
	ToolName string    `json:"tool_name"`           // name of the invoked tool
	ArgsJSON string    `json:"args_json,omitempty"` // filtered string args as JSON; may be empty
	Timestamp time.Time `json:"timestamp"`          // when the call was recorded (UTC)
	Processed bool      `json:"processed"`          // true once a background pass has consumed it
}
// InteractionLogRepo provides crash-safe tool call recording in SQLite.
// Every tool call is INSERT-ed immediately; WAL mode ensures durability
// even on kill -9 / terminal close.
type InteractionLogRepo struct {
	db *DB // shared SQLite handle
}
// NewInteractionLogRepo creates the interaction_log table if needed and returns the repo.
func NewInteractionLogRepo(db *DB) (*InteractionLogRepo, error) {
	// One row per tool call; processed=0 marks entries awaiting background
	// processing. Timestamps are stored as RFC3339 text (see Record).
	createSQL := `
	CREATE TABLE IF NOT EXISTS interaction_log (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		tool_name TEXT NOT NULL,
		args_json TEXT,
		timestamp TEXT NOT NULL,
		processed INTEGER DEFAULT 0
	)`
	if _, err := db.Exec(createSQL); err != nil {
		return nil, fmt.Errorf("create interaction_log table: %w", err)
	}
	return &InteractionLogRepo{db: db}, nil
}
// Record inserts a tool call entry. This is designed to be fire-and-forget
// from the middleware — errors are logged but don't break the tool call.
//
// Only non-empty string arguments are kept (to reduce noise). Long values
// are truncated to at most 200 bytes; the cut backs up to a UTF-8 rune
// boundary so json.Marshal never receives (and silently mangles) a
// half-encoded rune.
func (r *InteractionLogRepo) Record(ctx context.Context, toolName string, args map[string]interface{}) error {
	argsJSON := ""
	if len(args) > 0 {
		filtered := make(map[string]string, len(args))
		for k, v := range args {
			s, ok := v.(string)
			if !ok || s == "" {
				continue
			}
			if len(s) > 200 {
				// Back up from byte 200 to the start of a rune so the
				// truncated string stays valid UTF-8.
				cut := 200
				for cut > 0 && !utf8.RuneStart(s[cut]) {
					cut--
				}
				s = s[:cut] + "..."
			}
			filtered[k] = s
		}
		if len(filtered) > 0 {
			data, err := json.Marshal(filtered)
			if err != nil {
				return fmt.Errorf("marshal args: %w", err)
			}
			argsJSON = string(data)
		}
	}
	now := time.Now().UTC().Format(time.RFC3339)
	_, err := r.db.Exec(
		`INSERT INTO interaction_log (tool_name, args_json, timestamp) VALUES (?, ?, ?)`,
		toolName, argsJSON, now,
	)
	return err
}
// GetUnprocessed returns all entries not yet processed, ordered oldest first.
func (r *InteractionLogRepo) GetUnprocessed(ctx context.Context) ([]InteractionEntry, error) {
	rows, err := r.db.Query(
		`SELECT id, tool_name, args_json, timestamp, processed
		 FROM interaction_log WHERE processed = 0 ORDER BY id ASC`,
	)
	if err != nil {
		return nil, fmt.Errorf("query unprocessed: %w", err)
	}
	defer rows.Close()
	var out []InteractionEntry
	for rows.Next() {
		var (
			entry InteractionEntry
			ts    string
			proc  int
		)
		if err := rows.Scan(&entry.ID, &entry.ToolName, &entry.ArgsJSON, &ts, &proc); err != nil {
			return nil, fmt.Errorf("scan entry: %w", err)
		}
		// Timestamps are stored as RFC3339 text; a parse failure leaves
		// the zero time, matching the original behavior.
		entry.Timestamp, _ = time.Parse(time.RFC3339, ts)
		entry.Processed = proc != 0
		out = append(out, entry)
	}
	return out, rows.Err()
}
// MarkProcessed marks entries as processed by their IDs.
//
// IDs are batched into IN (...) statements — one round trip per chunk
// instead of one per ID. Chunks stay well under SQLite's default limit
// of 999 bound variables per statement.
func (r *InteractionLogRepo) MarkProcessed(ctx context.Context, ids []int64) error {
	if len(ids) == 0 {
		return nil
	}
	const chunkSize = 500
	for start := 0; start < len(ids); start += chunkSize {
		end := start + chunkSize
		if end > len(ids) {
			end = len(ids)
		}
		chunk := ids[start:end]
		placeholders := strings.TrimSuffix(strings.Repeat("?,", len(chunk)), ",")
		args := make([]interface{}, len(chunk))
		for i, id := range chunk {
			args[i] = id
		}
		if _, err := r.db.Exec(
			`UPDATE interaction_log SET processed = 1 WHERE id IN (`+placeholders+`)`,
			args...,
		); err != nil {
			return fmt.Errorf("mark processed: %w", err)
		}
	}
	return nil
}
// Count returns the total number of entries and how many of them are
// still unprocessed.
func (r *InteractionLogRepo) Count(ctx context.Context) (total int, unprocessed int, err error) {
	err = r.db.QueryRow(`SELECT COUNT(*), COALESCE(SUM(CASE WHEN processed=0 THEN 1 ELSE 0 END), 0) FROM interaction_log`).
		Scan(&total, &unprocessed)
	return total, unprocessed, err
}
// Prune deletes processed entries whose timestamp is at or before
// now-olderThan, returning the number of rows removed. Unprocessed
// entries are never deleted.
func (r *InteractionLogRepo) Prune(ctx context.Context, olderThan time.Duration) (int64, error) {
	cutoff := time.Now().UTC().Add(-olderThan).Format(time.RFC3339)
	res, err := r.db.Exec(
		`DELETE FROM interaction_log WHERE processed = 1 AND timestamp <= ?`, cutoff,
	)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

View file

@ -0,0 +1,186 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupInteractionRepo returns an InteractionLogRepo backed by an
// in-memory SQLite DB that is closed when the test finishes.
func setupInteractionRepo(t *testing.T) *InteractionLogRepo {
	t.Helper()
	db, err := OpenMemory()
	require.NoError(t, err)
	t.Cleanup(func() { db.Close() })
	repo, err := NewInteractionLogRepo(db)
	require.NoError(t, err)
	return repo
}
// TestNewInteractionLogRepo verifies construction and table migration succeed.
func TestNewInteractionLogRepo(t *testing.T) {
	repo := setupInteractionRepo(t)
	require.NotNil(t, repo)
}

// TestInteractionLogRepo_Record verifies a recorded call shows up as one
// unprocessed entry.
func TestInteractionLogRepo_Record(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	err := repo.Record(ctx, "add_fact", map[string]interface{}{
		"content": "test fact",
		"level":   0,
	})
	require.NoError(t, err)
	total, unproc, err := repo.Count(ctx)
	require.NoError(t, err)
	assert.Equal(t, 1, total)
	assert.Equal(t, 1, unproc)
}

// TestInteractionLogRepo_Record_EmptyArgs verifies nil args leave
// args_json empty.
func TestInteractionLogRepo_Record_EmptyArgs(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	err := repo.Record(ctx, "health", nil)
	require.NoError(t, err)
	entries, err := repo.GetUnprocessed(ctx)
	require.NoError(t, err)
	require.Len(t, entries, 1)
	assert.Equal(t, "health", entries[0].ToolName)
	assert.Empty(t, entries[0].ArgsJSON)
}

// TestInteractionLogRepo_Record_TruncatesLongValues verifies long string
// arguments are truncated with an ellipsis marker.
func TestInteractionLogRepo_Record_TruncatesLongValues(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	longContent := make([]byte, 500)
	for i := range longContent {
		longContent[i] = 'x'
	}
	err := repo.Record(ctx, "add_fact", map[string]interface{}{
		"content": string(longContent),
	})
	require.NoError(t, err)
	entries, err := repo.GetUnprocessed(ctx)
	require.NoError(t, err)
	require.Len(t, entries, 1)
	// args_json should contain truncated value (200 chars + "...")
	assert.Contains(t, entries[0].ArgsJSON, "...")
	assert.Less(t, len(entries[0].ArgsJSON), 300)
}
// TestInteractionLogRepo_GetUnprocessed verifies entries come back in
// insertion (id ASC) order.
func TestInteractionLogRepo_GetUnprocessed(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	_ = repo.Record(ctx, "add_fact", map[string]interface{}{"content": "a"})
	_ = repo.Record(ctx, "search_facts", map[string]interface{}{"query": "b"})
	_ = repo.Record(ctx, "health", nil)
	entries, err := repo.GetUnprocessed(ctx)
	require.NoError(t, err)
	assert.Len(t, entries, 3)
	// Ordered by id ASC
	assert.Equal(t, "add_fact", entries[0].ToolName)
	assert.Equal(t, "search_facts", entries[1].ToolName)
	assert.Equal(t, "health", entries[2].ToolName)
}

// TestInteractionLogRepo_MarkProcessed verifies processed entries leave
// the unprocessed set but remain in the total count.
func TestInteractionLogRepo_MarkProcessed(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	_ = repo.Record(ctx, "tool_a", nil)
	_ = repo.Record(ctx, "tool_b", nil)
	entries, _ := repo.GetUnprocessed(ctx)
	require.Len(t, entries, 2)
	// Mark first as processed
	err := repo.MarkProcessed(ctx, []int64{entries[0].ID})
	require.NoError(t, err)
	remaining, _ := repo.GetUnprocessed(ctx)
	assert.Len(t, remaining, 1)
	assert.Equal(t, "tool_b", remaining[0].ToolName)
	total, unproc, _ := repo.Count(ctx)
	assert.Equal(t, 2, total)
	assert.Equal(t, 1, unproc)
}

// TestInteractionLogRepo_MarkProcessed_Empty verifies a nil ID slice is a no-op.
func TestInteractionLogRepo_MarkProcessed_Empty(t *testing.T) {
	repo := setupInteractionRepo(t)
	err := repo.MarkProcessed(context.Background(), nil)
	require.NoError(t, err)
}
// TestInteractionLogRepo_Prune verifies processed entries are deleted.
func TestInteractionLogRepo_Prune(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	// Insert and mark as processed
	_ = repo.Record(ctx, "old_tool", nil)
	entries, _ := repo.GetUnprocessed(ctx)
	_ = repo.MarkProcessed(ctx, []int64{entries[0].ID})
	// Prune with 0 duration should delete all processed
	deleted, err := repo.Prune(ctx, 0)
	require.NoError(t, err)
	assert.Equal(t, int64(1), deleted)
	total, _, _ := repo.Count(ctx)
	assert.Equal(t, 0, total)
}

// TestInteractionLogRepo_Prune_KeepsUnprocessed verifies prune never
// touches unprocessed entries.
func TestInteractionLogRepo_Prune_KeepsUnprocessed(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	_ = repo.Record(ctx, "unprocessed_tool", nil)
	// Prune should not delete unprocessed entries
	deleted, err := repo.Prune(ctx, 0)
	require.NoError(t, err)
	assert.Equal(t, int64(0), deleted)
	total, _, _ := repo.Count(ctx)
	assert.Equal(t, 1, total)
}

// TestInteractionLogRepo_Timestamps verifies the recorded timestamp
// falls within the call window (with slack for RFC3339 second precision).
func TestInteractionLogRepo_Timestamps(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	before := time.Now().UTC()
	_ = repo.Record(ctx, "timed_tool", nil)
	after := time.Now().UTC()
	entries, _ := repo.GetUnprocessed(ctx)
	require.Len(t, entries, 1)
	ts := entries[0].Timestamp
	assert.False(t, ts.Before(before.Add(-time.Second)))
	assert.False(t, ts.After(after.Add(time.Second)))
}

// TestInteractionLogRepo_MultipleRecords_Count verifies counts track
// repeated inserts.
func TestInteractionLogRepo_MultipleRecords_Count(t *testing.T) {
	repo := setupInteractionRepo(t)
	ctx := context.Background()
	for i := 0; i < 10; i++ {
		_ = repo.Record(ctx, "tool", nil)
	}
	total, unproc, err := repo.Count(ctx)
	require.NoError(t, err)
	assert.Equal(t, 10, total)
	assert.Equal(t, 10, unproc)
}

View file

@ -0,0 +1,94 @@
package sqlite
import (
"context"
"fmt"
"time"
"github.com/sentinel-community/gomcp/internal/domain/peer"
)
// PeerRepo implements peer.PeerStore using SQLite.
type PeerRepo struct {
	db *DB // shared SQLite handle
}
// NewPeerRepo creates a PeerRepo and ensures the peers table exists.
func NewPeerRepo(db *DB) (*PeerRepo, error) {
	repo := &PeerRepo{db: db}
	if err := repo.migrate(); err != nil {
		return nil, fmt.Errorf("peer repo migrate: %w", err)
	}
	return repo, nil
}
// migrate creates the peers table and its indexes if they don't exist.
// Index creation errors are deliberately ignored (best-effort).
func (r *PeerRepo) migrate() error {
	stmt := `CREATE TABLE IF NOT EXISTS peers (
		peer_id TEXT PRIMARY KEY,
		node_name TEXT NOT NULL,
		genome_hash TEXT NOT NULL,
		trust_level INTEGER DEFAULT 0,
		last_seen TEXT NOT NULL,
		fact_count INTEGER DEFAULT 0,
		handshake_at TEXT,
		created_at TEXT DEFAULT CURRENT_TIMESTAMP
	)`
	if _, err := r.db.Exec(stmt); err != nil {
		return fmt.Errorf("create peers table: %w", err)
	}
	_, _ = r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_peers_trust ON peers(trust_level)`)
	_, _ = r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_peers_last_seen ON peers(last_seen)`)
	return nil
}
// SavePeer upserts a peer record.
//
// Uses INSERT ... ON CONFLICT DO UPDATE rather than INSERT OR REPLACE:
// OR REPLACE deletes and re-inserts the row, which would reset the
// existing row's created_at (DEFAULT CURRENT_TIMESTAMP) on every save.
func (r *PeerRepo) SavePeer(_ context.Context, p *peer.PeerInfo) error {
	stmt := `INSERT INTO peers (peer_id, node_name, genome_hash, trust_level, last_seen, fact_count, handshake_at)
		VALUES (?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT(peer_id) DO UPDATE SET
			node_name = excluded.node_name,
			genome_hash = excluded.genome_hash,
			trust_level = excluded.trust_level,
			last_seen = excluded.last_seen,
			fact_count = excluded.fact_count,
			handshake_at = excluded.handshake_at`
	// A zero handshake time is stored as "" (LoadPeers treats "" as unset).
	hsAt := ""
	if !p.HandshakeAt.IsZero() {
		hsAt = p.HandshakeAt.Format(timeFormat)
	}
	_, err := r.db.Exec(stmt,
		p.PeerID, p.NodeName, p.GenomeHash, int(p.Trust),
		p.LastSeen.Format(timeFormat), p.FactCount, hsAt,
	)
	return err
}
// LoadPeers returns all stored peers.
func (r *PeerRepo) LoadPeers(_ context.Context) ([]*peer.PeerInfo, error) {
	rows, err := r.db.Query(`SELECT peer_id, node_name, genome_hash, trust_level, last_seen, fact_count, handshake_at FROM peers`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var out []*peer.PeerInfo
	for rows.Next() {
		var (
			info        peer.PeerInfo
			trust       int
			lastSeen    string
			handshakeAt string
		)
		if err := rows.Scan(&info.PeerID, &info.NodeName, &info.GenomeHash, &trust, &lastSeen, &info.FactCount, &handshakeAt); err != nil {
			return nil, err
		}
		info.Trust = peer.TrustLevel(trust)
		info.LastSeen, _ = time.Parse(timeFormat, lastSeen)
		// An empty handshake_at means the peer never completed a handshake.
		if handshakeAt != "" {
			info.HandshakeAt, _ = time.Parse(timeFormat, handshakeAt)
		}
		out = append(out, &info)
	}
	return out, rows.Err()
}
// DeleteExpired removes peers not seen within the given duration and
// reports how many rows were removed.
func (r *PeerRepo) DeleteExpired(_ context.Context, olderThan time.Duration) (int, error) {
	cutoff := time.Now().Add(-olderThan).Format(timeFormat)
	res, err := r.db.Exec(`DELETE FROM peers WHERE last_seen < ?`, cutoff)
	if err != nil {
		return 0, err
	}
	removed, _ := res.RowsAffected()
	return int(removed), nil
}

View file

@ -0,0 +1,366 @@
package sqlite
import (
"database/sql"
"fmt"
"time"
"github.com/sentinel-community/gomcp/internal/domain/soc"
)
// SOCRepo provides SQLite persistence for SOC events, incidents, and sensors.
type SOCRepo struct {
	db *DB // shared SQLite handle
}
// NewSOCRepo creates and initializes SOC tables.
func NewSOCRepo(db *DB) (*SOCRepo, error) {
	repo := &SOCRepo{db: db}
	if err := repo.migrate(); err != nil {
		return nil, fmt.Errorf("soc_repo: migrate: %w", err)
	}
	return repo, nil
}
// migrate creates the SOC tables and indexes if they do not already exist.
// It is idempotent and safe to run on every startup.
func (r *SOCRepo) migrate() error {
	tables := []string{
		`CREATE TABLE IF NOT EXISTS soc_events (
			id TEXT PRIMARY KEY,
			source TEXT NOT NULL,
			sensor_id TEXT NOT NULL DEFAULT '',
			severity TEXT NOT NULL,
			category TEXT NOT NULL,
			subcategory TEXT NOT NULL DEFAULT '',
			confidence REAL NOT NULL DEFAULT 0.0,
			description TEXT NOT NULL DEFAULT '',
			session_id TEXT NOT NULL DEFAULT '',
			decision_hash TEXT NOT NULL DEFAULT '',
			verdict TEXT NOT NULL DEFAULT 'REVIEW',
			timestamp TEXT NOT NULL,
			metadata TEXT NOT NULL DEFAULT '{}'
		)`,
		`CREATE TABLE IF NOT EXISTS soc_incidents (
			id TEXT PRIMARY KEY,
			status TEXT NOT NULL DEFAULT 'OPEN',
			severity TEXT NOT NULL,
			title TEXT NOT NULL,
			description TEXT NOT NULL DEFAULT '',
			event_ids TEXT NOT NULL DEFAULT '[]',
			event_count INTEGER NOT NULL DEFAULT 0,
			decision_chain_anchor TEXT NOT NULL DEFAULT '',
			chain_length INTEGER NOT NULL DEFAULT 0,
			correlation_rule TEXT NOT NULL DEFAULT '',
			kill_chain_phase TEXT NOT NULL DEFAULT '',
			mitre_mapping TEXT NOT NULL DEFAULT '[]',
			playbook_applied TEXT NOT NULL DEFAULT '',
			created_at TEXT NOT NULL,
			updated_at TEXT NOT NULL,
			resolved_at TEXT
		)`,
		`CREATE TABLE IF NOT EXISTS soc_sensors (
			sensor_id TEXT PRIMARY KEY,
			sensor_type TEXT NOT NULL,
			status TEXT DEFAULT 'UNKNOWN',
			first_seen TEXT NOT NULL,
			last_seen TEXT NOT NULL,
			event_count INTEGER DEFAULT 0,
			missed_heartbeats INTEGER DEFAULT 0,
			hostname TEXT NOT NULL DEFAULT '',
			version TEXT NOT NULL DEFAULT ''
		)`,
		// Indexes for common queries.
		`CREATE INDEX IF NOT EXISTS idx_soc_events_timestamp ON soc_events(timestamp)`,
		`CREATE INDEX IF NOT EXISTS idx_soc_events_severity ON soc_events(severity)`,
		`CREATE INDEX IF NOT EXISTS idx_soc_events_category ON soc_events(category)`,
		`CREATE INDEX IF NOT EXISTS idx_soc_events_sensor ON soc_events(sensor_id)`,
		`CREATE INDEX IF NOT EXISTS idx_soc_incidents_status ON soc_incidents(status)`,
		`CREATE INDEX IF NOT EXISTS idx_soc_sensors_status ON soc_sensors(status)`,
	}
	for _, ddl := range tables {
		if _, err := r.db.Exec(ddl); err != nil {
			// Truncate the DDL for the error message, but guard the slice:
			// ddl[:40] would panic on any statement shorter than 40 bytes.
			head := ddl
			if len(head) > 40 {
				head = head[:40]
			}
			return fmt.Errorf("exec %q: %w", head, err)
		}
	}
	return nil
}
// === Events ===

// InsertEvent persists a SOC event.
// The metadata column is left at its '{}' default.
func (r *SOCRepo) InsertEvent(e soc.SOCEvent) error {
	const stmt = `INSERT INTO soc_events (id, source, sensor_id, severity, category, subcategory,
confidence, description, session_id, decision_hash, verdict, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
	ts := e.Timestamp.Format(time.RFC3339Nano)
	_, err := r.db.Exec(stmt,
		e.ID, e.Source, e.SensorID, e.Severity, e.Category, e.Subcategory,
		e.Confidence, e.Description, e.SessionID, e.DecisionHash, e.Verdict, ts,
	)
	return err
}
// ListEvents returns events ordered by timestamp (newest first), with limit.
// A non-positive limit falls back to 50.
func (r *SOCRepo) ListEvents(limit int) ([]soc.SOCEvent, error) {
	if limit <= 0 {
		limit = 50
	}
	const q = `SELECT id, source, sensor_id, severity, category, subcategory,
confidence, description, session_id, decision_hash, verdict, timestamp
FROM soc_events ORDER BY timestamp DESC LIMIT ?`
	rows, err := r.db.Query(q, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanEvents(rows)
}
// ListEventsByCategory returns events filtered by category.
// A non-positive limit falls back to 50.
func (r *SOCRepo) ListEventsByCategory(category string, limit int) ([]soc.SOCEvent, error) {
	if limit <= 0 {
		limit = 50
	}
	const q = `SELECT id, source, sensor_id, severity, category, subcategory,
confidence, description, session_id, decision_hash, verdict, timestamp
FROM soc_events WHERE category = ? ORDER BY timestamp DESC LIMIT ?`
	rows, err := r.db.Query(q, category, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanEvents(rows)
}
// CountEvents returns total event count.
func (r *SOCRepo) CountEvents() (int, error) {
	var n int
	if err := r.db.QueryRow("SELECT COUNT(*) FROM soc_events").Scan(&n); err != nil {
		return 0, err
	}
	return n, nil
}
// CountEventsSince returns the number of events at or after the given instant.
// The comparison is lexical on the stored RFC3339Nano text.
func (r *SOCRepo) CountEventsSince(since time.Time) (int, error) {
	var n int
	if err := r.db.QueryRow(
		"SELECT COUNT(*) FROM soc_events WHERE timestamp >= ?",
		since.Format(time.RFC3339Nano),
	).Scan(&n); err != nil {
		return 0, err
	}
	return n, nil
}
// scanEvents converts the current result set into SOCEvent values.
// The caller retains ownership of rows and must Close them.
func scanEvents(rows *sql.Rows) ([]soc.SOCEvent, error) {
	var out []soc.SOCEvent
	for rows.Next() {
		var (
			e  soc.SOCEvent
			ts string
		)
		if err := rows.Scan(&e.ID, &e.Source, &e.SensorID, &e.Severity,
			&e.Category, &e.Subcategory, &e.Confidence, &e.Description,
			&e.SessionID, &e.DecisionHash, &e.Verdict, &ts); err != nil {
			return nil, err
		}
		// Stored as RFC3339Nano text; a parse failure leaves the zero time.
		e.Timestamp, _ = time.Parse(time.RFC3339Nano, ts)
		out = append(out, e)
	}
	return out, rows.Err()
}
// === Incidents ===

// InsertIncident persists a new incident.
// Only the scalar fields are written here; event_ids, mitre_mapping and
// playbook_applied fall back to their column defaults ('[]' / ''), and
// resolved_at stays NULL until UpdateIncidentStatus resolves the incident.
func (r *SOCRepo) InsertIncident(inc soc.Incident) error {
	_, err := r.db.Exec(
		`INSERT INTO soc_incidents (id, status, severity, title, description,
event_count, decision_chain_anchor, chain_length, correlation_rule,
kill_chain_phase, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		inc.ID, inc.Status, inc.Severity, inc.Title, inc.Description,
		inc.EventCount, inc.DecisionChainAnchor, inc.ChainLength,
		inc.CorrelationRule, inc.KillChainPhase,
		inc.CreatedAt.Format(time.RFC3339Nano),
		inc.UpdatedAt.Format(time.RFC3339Nano),
	)
	return err
}
// GetIncident retrieves an incident by ID.
// Timestamps are stored as RFC3339Nano text; resolved_at may be NULL and is
// mapped to a nil ResolvedAt pointer. A missing row surfaces as the driver's
// not-found error (sql.ErrNoRows) unchanged.
func (r *SOCRepo) GetIncident(id string) (*soc.Incident, error) {
	var inc soc.Incident
	var createdAt, updatedAt string
	var resolvedAt sql.NullString
	// The Scan order must match the SELECT column order exactly.
	err := r.db.QueryRow(
		`SELECT id, status, severity, title, description, event_count,
decision_chain_anchor, chain_length, correlation_rule,
kill_chain_phase, playbook_applied, created_at, updated_at, resolved_at
FROM soc_incidents WHERE id = ?`, id,
	).Scan(&inc.ID, &inc.Status, &inc.Severity, &inc.Title, &inc.Description,
		&inc.EventCount, &inc.DecisionChainAnchor, &inc.ChainLength,
		&inc.CorrelationRule, &inc.KillChainPhase, &inc.PlaybookApplied,
		&createdAt, &updatedAt, &resolvedAt)
	if err != nil {
		return nil, err
	}
	// Parse failures intentionally degrade to the zero time.
	inc.CreatedAt, _ = time.Parse(time.RFC3339Nano, createdAt)
	inc.UpdatedAt, _ = time.Parse(time.RFC3339Nano, updatedAt)
	if resolvedAt.Valid {
		t, _ := time.Parse(time.RFC3339Nano, resolvedAt.String)
		inc.ResolvedAt = &t
	}
	return &inc, nil
}
// ListIncidents returns incidents newest-first, optionally filtered by status.
// An empty status means "all"; a non-positive limit falls back to 50.
func (r *SOCRepo) ListIncidents(status string, limit int) ([]soc.Incident, error) {
	if limit <= 0 {
		limit = 50
	}
	var (
		rows *sql.Rows
		err  error
	)
	switch {
	case status != "":
		rows, err = r.db.Query(
			`SELECT id, status, severity, title, description, event_count,
decision_chain_anchor, chain_length, correlation_rule,
kill_chain_phase, playbook_applied, created_at, updated_at
FROM soc_incidents WHERE status = ? ORDER BY created_at DESC LIMIT ?`,
			status, limit)
	default:
		rows, err = r.db.Query(
			`SELECT id, status, severity, title, description, event_count,
decision_chain_anchor, chain_length, correlation_rule,
kill_chain_phase, playbook_applied, created_at, updated_at
FROM soc_incidents ORDER BY created_at DESC LIMIT ?`, limit)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var out []soc.Incident
	for rows.Next() {
		var (
			inc                  soc.Incident
			createdAt, updatedAt string
		)
		if err := rows.Scan(&inc.ID, &inc.Status, &inc.Severity, &inc.Title,
			&inc.Description, &inc.EventCount, &inc.DecisionChainAnchor,
			&inc.ChainLength, &inc.CorrelationRule, &inc.KillChainPhase,
			&inc.PlaybookApplied, &createdAt, &updatedAt); err != nil {
			return nil, err
		}
		inc.CreatedAt, _ = time.Parse(time.RFC3339Nano, createdAt)
		inc.UpdatedAt, _ = time.Parse(time.RFC3339Nano, updatedAt)
		out = append(out, inc)
	}
	return out, rows.Err()
}
// UpdateIncidentStatus updates an incident's status; terminal statuses
// (RESOLVED, FALSE_POSITIVE) also stamp resolved_at.
func (r *SOCRepo) UpdateIncidentStatus(id string, status soc.IncidentStatus) error {
	now := time.Now().Format(time.RFC3339Nano)
	terminal := status == soc.StatusResolved || status == soc.StatusFalsePositive
	var err error
	if terminal {
		_, err = r.db.Exec(
			`UPDATE soc_incidents SET status = ?, updated_at = ?, resolved_at = ? WHERE id = ?`,
			status, now, now, id)
	} else {
		_, err = r.db.Exec(
			`UPDATE soc_incidents SET status = ?, updated_at = ? WHERE id = ?`,
			status, now, id)
	}
	return err
}
// CountOpenIncidents returns the number of incidents still in an active
// state (OPEN or INVESTIGATING).
func (r *SOCRepo) CountOpenIncidents() (int, error) {
	var n int
	if err := r.db.QueryRow(
		"SELECT COUNT(*) FROM soc_incidents WHERE status IN ('OPEN', 'INVESTIGATING')",
	).Scan(&n); err != nil {
		return 0, err
	}
	return n, nil
}
// === Sensors ===

// UpsertSensor creates or updates a sensor entry.
// On conflict only the mutable counters (status, last_seen, event_count,
// missed_heartbeats) are overwritten; first_seen, hostname and version keep
// the values from the original insert.
func (r *SOCRepo) UpsertSensor(s soc.Sensor) error {
	_, err := r.db.Exec(
		`INSERT INTO soc_sensors (sensor_id, sensor_type, status, first_seen, last_seen,
event_count, missed_heartbeats, hostname, version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(sensor_id) DO UPDATE SET
status = excluded.status,
last_seen = excluded.last_seen,
event_count = excluded.event_count,
missed_heartbeats = excluded.missed_heartbeats`,
		s.SensorID, s.SensorType, s.Status,
		s.FirstSeen.Format(time.RFC3339Nano),
		s.LastSeen.Format(time.RFC3339Nano),
		s.EventCount, s.MissedHeartbeats, s.Hostname, s.Version,
	)
	return err
}
// GetSensor retrieves a sensor by ID.
// A missing row surfaces as the driver's not-found error unchanged.
func (r *SOCRepo) GetSensor(id string) (*soc.Sensor, error) {
	var (
		s                   soc.Sensor
		firstSeen, lastSeen string
	)
	err := r.db.QueryRow(
		`SELECT sensor_id, sensor_type, status, first_seen, last_seen,
event_count, missed_heartbeats, hostname, version
FROM soc_sensors WHERE sensor_id = ?`, id,
	).Scan(&s.SensorID, &s.SensorType, &s.Status, &firstSeen, &lastSeen,
		&s.EventCount, &s.MissedHeartbeats, &s.Hostname, &s.Version)
	if err != nil {
		return nil, err
	}
	// Parse failures intentionally degrade to the zero time.
	s.FirstSeen, _ = time.Parse(time.RFC3339Nano, firstSeen)
	s.LastSeen, _ = time.Parse(time.RFC3339Nano, lastSeen)
	return &s, nil
}
// ListSensors returns all registered sensors, most recently seen first.
func (r *SOCRepo) ListSensors() ([]soc.Sensor, error) {
	rows, err := r.db.Query(
		`SELECT sensor_id, sensor_type, status, first_seen, last_seen,
event_count, missed_heartbeats, hostname, version
FROM soc_sensors ORDER BY last_seen DESC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var out []soc.Sensor
	for rows.Next() {
		var (
			s                   soc.Sensor
			firstSeen, lastSeen string
		)
		if err := rows.Scan(&s.SensorID, &s.SensorType, &s.Status,
			&firstSeen, &lastSeen, &s.EventCount, &s.MissedHeartbeats,
			&s.Hostname, &s.Version); err != nil {
			return nil, err
		}
		s.FirstSeen, _ = time.Parse(time.RFC3339Nano, firstSeen)
		s.LastSeen, _ = time.Parse(time.RFC3339Nano, lastSeen)
		out = append(out, s)
	}
	return out, rows.Err()
}
// CountSensorsByStatus returns sensor count grouped by status.
// The status column is scanned directly into soc.SensorStatus — assumes
// SensorStatus has string as its underlying type, so database/sql can
// convert the TEXT column (TODO confirm against the soc package).
func (r *SOCRepo) CountSensorsByStatus() (map[soc.SensorStatus]int, error) {
	rows, err := r.db.Query("SELECT status, COUNT(*) FROM soc_sensors GROUP BY status")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	result := make(map[soc.SensorStatus]int)
	for rows.Next() {
		var status soc.SensorStatus
		var count int
		if err := rows.Scan(&status, &count); err != nil {
			return nil, err
		}
		result[status] = count
	}
	return result, rows.Err()
}

View file

@ -0,0 +1,263 @@
package sqlite
import (
"testing"
"time"
"github.com/sentinel-community/gomcp/internal/domain/soc"
)
// setupSOCRepo returns a SOCRepo backed by a fresh in-memory database.
// The database is closed automatically when the test finishes.
func setupSOCRepo(t *testing.T) *SOCRepo {
	t.Helper()
	db, err := OpenMemory()
	if err != nil {
		t.Fatalf("open memory db: %v", err)
	}
	t.Cleanup(func() { db.Close() })
	repo, err := NewSOCRepo(db)
	if err != nil {
		t.Fatalf("new soc repo: %v", err)
	}
	return repo
}
// === Event Tests ===

// TestInsertAndListEvents verifies the insert -> list -> count round trip.
func TestInsertAndListEvents(t *testing.T) {
	repo := setupSOCRepo(t)
	// Two events from distinct sources, severities, and sensors.
	e1 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityHigh, "jailbreak", "Jailbreak detected").
		WithSensor("core-01").WithConfidence(0.95)
	e2 := soc.NewSOCEvent(soc.SourceShield, soc.SeverityMedium, "network_block", "Connection blocked").
		WithSensor("shield-01")
	if err := repo.InsertEvent(e1); err != nil {
		t.Fatalf("insert e1: %v", err)
	}
	if err := repo.InsertEvent(e2); err != nil {
		t.Fatalf("insert e2: %v", err)
	}
	events, err := repo.ListEvents(10)
	if err != nil {
		t.Fatalf("list events: %v", err)
	}
	if len(events) != 2 {
		t.Errorf("expected 2 events, got %d", len(events))
	}
	count, err := repo.CountEvents()
	if err != nil {
		t.Fatalf("count: %v", err)
	}
	if count != 2 {
		t.Errorf("expected count 2, got %d", count)
	}
}
// TestListEventsByCategory checks that category filtering returns only
// matching events. Insert errors are now checked instead of ignored.
func TestListEventsByCategory(t *testing.T) {
	repo := setupSOCRepo(t)
	// Sleeps keep stored timestamps strictly increasing between inserts.
	e1 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityHigh, "jailbreak", "test")
	if err := repo.InsertEvent(e1); err != nil {
		t.Fatalf("insert e1: %v", err)
	}
	time.Sleep(time.Millisecond)
	e2 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityMedium, "injection", "test")
	if err := repo.InsertEvent(e2); err != nil {
		t.Fatalf("insert e2: %v", err)
	}
	time.Sleep(time.Millisecond)
	e3 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityLow, "jailbreak", "test2")
	if err := repo.InsertEvent(e3); err != nil {
		t.Fatalf("insert e3: %v", err)
	}
	events, err := repo.ListEventsByCategory("jailbreak", 10)
	if err != nil {
		t.Fatalf("list by category: %v", err)
	}
	if len(events) != 2 {
		t.Errorf("expected 2 jailbreak events, got %d", len(events))
	}
}
// === Incident Tests ===

// TestInsertAndGetIncident verifies an incident round-trips through SQLite,
// including the decision-chain anchor fields.
func TestInsertAndGetIncident(t *testing.T) {
	repo := setupSOCRepo(t)
	inc := soc.NewIncident("Multi-stage Jailbreak", soc.SeverityCritical, "jailbreak_chain")
	inc.SetAnchor("abc123", 5)
	if err := repo.InsertIncident(inc); err != nil {
		t.Fatalf("insert incident: %v", err)
	}
	got, err := repo.GetIncident(inc.ID)
	if err != nil {
		t.Fatalf("get incident: %v", err)
	}
	if got.ID != inc.ID {
		t.Errorf("ID mismatch: got %s, want %s", got.ID, inc.ID)
	}
	if got.DecisionChainAnchor != "abc123" {
		t.Errorf("anchor mismatch: got %s", got.DecisionChainAnchor)
	}
	if got.ChainLength != 5 {
		t.Errorf("chain length: got %d, want 5", got.ChainLength)
	}
}
// TestUpdateIncidentStatus verifies that resolving an incident updates the
// status and stamps resolved_at. The setup insert error is now checked.
func TestUpdateIncidentStatus(t *testing.T) {
	repo := setupSOCRepo(t)
	inc := soc.NewIncident("Test", soc.SeverityHigh, "test_rule")
	if err := repo.InsertIncident(inc); err != nil {
		t.Fatalf("insert incident: %v", err)
	}
	if err := repo.UpdateIncidentStatus(inc.ID, soc.StatusResolved); err != nil {
		t.Fatalf("update status: %v", err)
	}
	got, err := repo.GetIncident(inc.ID)
	if err != nil {
		t.Fatalf("get after update: %v", err)
	}
	if got.Status != soc.StatusResolved {
		t.Errorf("expected RESOLVED, got %s", got.Status)
	}
	if got.ResolvedAt == nil {
		t.Error("resolved_at should be set")
	}
}
// TestListIncidentsWithFilter verifies status filtering and the unfiltered
// listing. Setup errors are now checked instead of ignored.
func TestListIncidentsWithFilter(t *testing.T) {
	repo := setupSOCRepo(t)
	inc1 := soc.NewIncident("Open Inc", soc.SeverityHigh, "rule1")
	inc2 := soc.NewIncident("Resolved Inc", soc.SeverityMedium, "rule2")
	if err := repo.InsertIncident(inc1); err != nil {
		t.Fatalf("insert inc1: %v", err)
	}
	if err := repo.InsertIncident(inc2); err != nil {
		t.Fatalf("insert inc2: %v", err)
	}
	if err := repo.UpdateIncidentStatus(inc2.ID, soc.StatusResolved); err != nil {
		t.Fatalf("resolve inc2: %v", err)
	}
	// List OPEN only
	open, err := repo.ListIncidents("OPEN", 10)
	if err != nil {
		t.Fatalf("list open: %v", err)
	}
	if len(open) != 1 {
		t.Errorf("expected 1 open incident, got %d", len(open))
	}
	// List all
	all, err := repo.ListIncidents("", 10)
	if err != nil {
		t.Fatalf("list all: %v", err)
	}
	if len(all) != 2 {
		t.Errorf("expected 2 total incidents, got %d", len(all))
	}
}
// TestCountOpenIncidents verifies that OPEN and INVESTIGATING both count as
// open while RESOLVED does not. Setup errors are now checked.
func TestCountOpenIncidents(t *testing.T) {
	repo := setupSOCRepo(t)
	inc1 := soc.NewIncident("Open", soc.SeverityHigh, "r1")
	inc2 := soc.NewIncident("Investigating", soc.SeverityMedium, "r2")
	inc3 := soc.NewIncident("Resolved", soc.SeverityLow, "r3")
	for i, inc := range []soc.Incident{inc1, inc2, inc3} {
		if err := repo.InsertIncident(inc); err != nil {
			t.Fatalf("insert incident %d: %v", i+1, err)
		}
	}
	if err := repo.UpdateIncidentStatus(inc2.ID, soc.StatusInvestigating); err != nil {
		t.Fatalf("update inc2: %v", err)
	}
	if err := repo.UpdateIncidentStatus(inc3.ID, soc.StatusResolved); err != nil {
		t.Fatalf("update inc3: %v", err)
	}
	count, err := repo.CountOpenIncidents()
	if err != nil {
		t.Fatalf("count open: %v", err)
	}
	if count != 2 {
		t.Errorf("expected 2 open (OPEN+INVESTIGATING), got %d", count)
	}
}
// === Sensor Tests ===

// TestUpsertAndGetSensor verifies a sensor round-trips with its recorded
// event count and derived status.
func TestUpsertAndGetSensor(t *testing.T) {
	repo := setupSOCRepo(t)
	s := soc.NewSensor("core-01", soc.SensorTypeSentinelCore)
	s.RecordEvent()
	s.RecordEvent()
	s.RecordEvent() // Should be HEALTHY
	if err := repo.UpsertSensor(s); err != nil {
		t.Fatalf("upsert: %v", err)
	}
	got, err := repo.GetSensor("core-01")
	if err != nil {
		t.Fatalf("get sensor: %v", err)
	}
	if got.Status != soc.SensorStatusHealthy {
		t.Errorf("expected HEALTHY, got %s", got.Status)
	}
	if got.EventCount != 3 {
		t.Errorf("expected 3 events, got %d", got.EventCount)
	}
}
// TestSensorUpsertUpdate verifies the second upsert updates the mutable
// counters of an existing row. Upsert errors are now checked.
func TestSensorUpsertUpdate(t *testing.T) {
	repo := setupSOCRepo(t)
	s := soc.NewSensor("shield-01", soc.SensorTypeShield)
	if err := repo.UpsertSensor(s); err != nil {
		t.Fatalf("first upsert: %v", err)
	}
	// Update with new status
	s.RecordEvent()
	s.RecordEvent()
	s.RecordEvent()
	if err := repo.UpsertSensor(s); err != nil {
		t.Fatalf("second upsert: %v", err)
	}
	got, err := repo.GetSensor("shield-01")
	if err != nil {
		t.Fatalf("get sensor: %v", err)
	}
	if got.EventCount != 3 {
		t.Errorf("upsert should update event_count, got %d", got.EventCount)
	}
}
// TestListSensors verifies both upserted sensors come back from ListSensors.
// Upsert errors are now checked instead of ignored.
func TestListSensors(t *testing.T) {
	repo := setupSOCRepo(t)
	if err := repo.UpsertSensor(soc.NewSensor("core-01", soc.SensorTypeSentinelCore)); err != nil {
		t.Fatalf("upsert core-01: %v", err)
	}
	if err := repo.UpsertSensor(soc.NewSensor("shield-01", soc.SensorTypeShield)); err != nil {
		t.Fatalf("upsert shield-01: %v", err)
	}
	sensors, err := repo.ListSensors()
	if err != nil {
		t.Fatalf("list: %v", err)
	}
	if len(sensors) != 2 {
		t.Errorf("expected 2, got %d", len(sensors))
	}
}
// TestCountSensorsByStatus verifies the per-status grouping. Upsert errors
// are now checked instead of ignored.
func TestCountSensorsByStatus(t *testing.T) {
	repo := setupSOCRepo(t)
	s1 := soc.NewSensor("core-01", soc.SensorTypeSentinelCore)
	s1.RecordEvent()
	s1.RecordEvent()
	s1.RecordEvent() // HEALTHY
	s2 := soc.NewSensor("shield-01", soc.SensorTypeShield) // UNKNOWN
	if err := repo.UpsertSensor(s1); err != nil {
		t.Fatalf("upsert s1: %v", err)
	}
	if err := repo.UpsertSensor(s2); err != nil {
		t.Fatalf("upsert s2: %v", err)
	}
	counts, err := repo.CountSensorsByStatus()
	if err != nil {
		t.Fatalf("count by status: %v", err)
	}
	if counts[soc.SensorStatusHealthy] != 1 {
		t.Errorf("expected 1 HEALTHY, got %d", counts[soc.SensorStatusHealthy])
	}
	if counts[soc.SensorStatusUnknown] != 1 {
		t.Errorf("expected 1 UNKNOWN, got %d", counts[soc.SensorStatusUnknown])
	}
}

View file

@ -0,0 +1,215 @@
package sqlite
import (
	"context"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/sentinel-community/gomcp/internal/domain/session"
)
// StateRepo implements session.StateStore using SQLite.
// Compatible with memory_bridge.db schema (states + audit_log).
//
// NOTE: The Python version uses AES-256-GCM encryption on the data blob.
// This Go implementation stores plaintext JSON for now — encryption
// can be layered on top via a decorator if needed.
type StateRepo struct {
	db *DB // shared SQLite handle; schema ensured by migrate()
}
// NewStateRepo creates a StateRepo and ensures the schema exists.
func NewStateRepo(db *DB) (*StateRepo, error) {
	r := &StateRepo{db: db}
	if err := r.migrate(); err != nil {
		return nil, fmt.Errorf("state repo migrate: %w", err)
	}
	return r, nil
}
// migrate creates the states and audit_log tables plus their indexes.
// All statements are idempotent (IF NOT EXISTS), so this is safe on every
// startup. The nonce column exists for schema compatibility with the
// encrypted Python variant and is unused here.
func (r *StateRepo) migrate() error {
	stmts := []string{
		`CREATE TABLE IF NOT EXISTS states (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
version INTEGER NOT NULL,
timestamp TEXT NOT NULL,
checksum TEXT NOT NULL,
data BLOB NOT NULL,
nonce BLOB,
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
UNIQUE(session_id, version)
)`,
		`CREATE INDEX IF NOT EXISTS idx_session_id ON states(session_id)`,
		`CREATE INDEX IF NOT EXISTS idx_timestamp ON states(timestamp)`,
		`CREATE TABLE IF NOT EXISTS audit_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
action TEXT NOT NULL,
version INTEGER,
timestamp TEXT NOT NULL,
details TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
)`,
		`CREATE INDEX IF NOT EXISTS idx_audit_session ON audit_log(session_id)`,
		`CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp)`,
	}
	for _, s := range stmts {
		if _, err := r.db.Exec(s); err != nil {
			return fmt.Errorf("exec migration: %w", err)
		}
	}
	return nil
}
// Save persists a cognitive state vector snapshot and writes an audit entry.
//
// The caller-supplied checksum is stored as-is; when it is empty a
// SHA-256 of the JSON payload is computed instead. No verification of a
// supplied checksum is performed here.
//
// NOTE(review): the existence check, state insert, and audit insert run as
// three separate statements without a transaction, and the insert will fail
// on the UNIQUE(session_id, version) constraint if the same version is
// saved twice — confirm callers always bump the version first.
func (r *StateRepo) Save(ctx context.Context, state *session.CognitiveStateVector, checksum string) error {
	data, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("marshal state: %w", err)
	}
	// Compute a checksum only when the caller did not supply one.
	if checksum == "" {
		h := sha256.Sum256(data)
		checksum = hex.EncodeToString(h[:])
	}
	now := time.Now().Format(timeFormat)
	// Determine action for audit log: first snapshot => "create",
	// any later version => "update".
	var action string
	var existingCount int
	row := r.db.QueryRow(`SELECT COUNT(*) FROM states WHERE session_id = ?`, state.SessionID)
	if err := row.Scan(&existingCount); err != nil {
		return fmt.Errorf("count existing: %w", err)
	}
	if existingCount == 0 {
		action = "create"
	} else {
		action = "update"
	}
	_, err = r.db.Exec(`INSERT INTO states (session_id, version, timestamp, checksum, data)
VALUES (?, ?, ?, ?, ?)`,
		state.SessionID, state.Version, now, checksum, data,
	)
	if err != nil {
		return fmt.Errorf("insert state: %w", err)
	}
	// Write audit log entry.
	_, err = r.db.Exec(`INSERT INTO audit_log (session_id, action, version, timestamp, details)
VALUES (?, ?, ?, ?, ?)`,
		state.SessionID, action, state.Version, now,
		fmt.Sprintf("%s session %s v%d", action, state.SessionID, state.Version),
	)
	if err != nil {
		return fmt.Errorf("insert audit: %w", err)
	}
	return nil
}
// Load retrieves a cognitive state vector. If version is nil, loads the latest.
// It returns the decoded state together with the checksum stored alongside it.
func (r *StateRepo) Load(ctx context.Context, sessionID string, version *int) (*session.CognitiveStateVector, string, error) {
	var row *sql.Row
	if version != nil {
		row = r.db.QueryRow(`SELECT data, checksum FROM states WHERE session_id = ? AND version = ?`,
			sessionID, *version)
	} else {
		row = r.db.QueryRow(`SELECT data, checksum FROM states WHERE session_id = ? ORDER BY version DESC LIMIT 1`,
			sessionID)
	}
	var data []byte
	var checksum string
	if err := row.Scan(&data, &checksum); err != nil {
		// errors.Is also matches drivers that wrap sql.ErrNoRows,
		// unlike the previous direct == comparison.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, "", fmt.Errorf("session %s not found", sessionID)
		}
		return nil, "", fmt.Errorf("scan state: %w", err)
	}
	var state session.CognitiveStateVector
	if err := json.Unmarshal(data, &state); err != nil {
		return nil, "", fmt.Errorf("unmarshal state: %w", err)
	}
	return &state, checksum, nil
}
// ListSessions returns metadata about all persisted sessions, grouped by
// session_id with the highest version and most recent timestamp.
func (r *StateRepo) ListSessions(ctx context.Context) ([]session.SessionInfo, error) {
	rows, err := r.db.Query(`SELECT session_id, MAX(version) as version, MAX(timestamp) as updated_at
FROM states GROUP BY session_id ORDER BY updated_at DESC`)
	if err != nil {
		return nil, fmt.Errorf("list sessions: %w", err)
	}
	defer rows.Close()
	var out []session.SessionInfo
	for rows.Next() {
		var (
			info      session.SessionInfo
			updatedAt string
		)
		if err := rows.Scan(&info.SessionID, &info.Version, &updatedAt); err != nil {
			return nil, fmt.Errorf("scan session info: %w", err)
		}
		// A parse failure leaves the zero time.
		info.UpdatedAt, _ = time.Parse(timeFormat, updatedAt)
		out = append(out, info)
	}
	return out, rows.Err()
}
// DeleteSession removes all versions of a session. Returns the number of deleted rows.
func (r *StateRepo) DeleteSession(ctx context.Context, sessionID string) (int, error) {
	now := time.Now().Format(timeFormat)
	res, err := r.db.Exec(`DELETE FROM states WHERE sessiovn_id = ?`, sessionID)
	if err != nil {
		return 0, fmt.Errorf("delete session: %w", err)
	}
	deleted, _ := res.RowsAffected()
	// Audit entry is best-effort; deletion has already succeeded.
	_, _ = r.db.Exec(`INSERT INTO audit_log (session_id, action, timestamp, details)
VALUES (?, 'delete', ?, ?)`,
		sessionID, now, fmt.Sprintf("deleted %d versions of session %s", deleted, sessionID),
	)
	return int(deleted), nil
}
// GetAuditLog returns the audit log for a session, newest entries first.
// A non-positive limit falls back to 50. The NULLable version column is
// mapped to 0 when absent (e.g. for delete entries).
// Timestamp is scanned directly into e.Timestamp — assumes
// session.AuditEntry.Timestamp is a string-compatible field; confirm
// against the session package.
func (r *StateRepo) GetAuditLog(ctx context.Context, sessionID string, limit int) ([]session.AuditEntry, error) {
	if limit <= 0 {
		limit = 50
	}
	rows, err := r.db.Query(`SELECT session_id, action, version, timestamp, details
FROM audit_log WHERE session_id = ? ORDER BY id DESC LIMIT ?`,
		sessionID, limit)
	if err != nil {
		return nil, fmt.Errorf("get audit log: %w", err)
	}
	defer rows.Close()
	var entries []session.AuditEntry
	for rows.Next() {
		var e session.AuditEntry
		var version sql.NullInt64
		if err := rows.Scan(&e.SessionID, &e.Action, &version, &e.Timestamp, &e.Details); err != nil {
			return nil, fmt.Errorf("scan audit entry: %w", err)
		}
		if version.Valid {
			e.Version = int(version.Int64)
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// Ensure StateRepo implements session.StateStore.
var _ session.StateStore = (*StateRepo)(nil)

View file

@ -0,0 +1,163 @@
package sqlite
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/session"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestStateRepo returns a StateRepo over a fresh in-memory database that
// is closed automatically when the test finishes.
func newTestStateRepo(t *testing.T) *StateRepo {
	t.Helper()
	db, err := OpenMemory()
	require.NoError(t, err)
	t.Cleanup(func() { db.Close() })
	repo, err := NewStateRepo(db)
	require.NoError(t, err)
	return repo
}
// TestStateRepo_Save_Load verifies a populated state vector round-trips
// through Save/Load with its checksum intact.
func TestStateRepo_Save_Load(t *testing.T) {
	repo := newTestStateRepo(t)
	ctx := context.Background()
	csv := session.NewCognitiveStateVector("test-session")
	csv.SetGoal("Build GoMCP", 0.3)
	csv.AddHypothesis("SQLite is fast enough")
	csv.AddDecision("Use mcp-go", "mature lib", []string{"custom"})
	csv.AddFact("Go 1.25", "requirement", 1.0)
	checksum := csv.Checksum()
	err := repo.Save(ctx, csv, checksum)
	require.NoError(t, err)
	loaded, gotChecksum, err := repo.Load(ctx, "test-session", nil)
	require.NoError(t, err)
	require.NotNil(t, loaded)
	assert.Equal(t, csv.SessionID, loaded.SessionID)
	assert.Equal(t, csv.Version, loaded.Version)
	assert.Equal(t, checksum, gotChecksum)
	require.NotNil(t, loaded.PrimaryGoal)
	assert.Equal(t, "Build GoMCP", loaded.PrimaryGoal.Description)
	assert.Len(t, loaded.Hypotheses, 1)
	assert.Len(t, loaded.Decisions, 1)
	assert.Len(t, loaded.Facts, 1)
}
// TestStateRepo_Save_Versioning verifies that multiple versions coexist:
// Load(nil) returns the latest, and a specific version can still be read.
func TestStateRepo_Save_Versioning(t *testing.T) {
	repo := newTestStateRepo(t)
	ctx := context.Background()
	csv := session.NewCognitiveStateVector("s1")
	csv.SetGoal("v1", 0.1)
	require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
	csv.BumpVersion()
	csv.SetGoal("v2", 0.5)
	require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
	// Load latest
	loaded, _, err := repo.Load(ctx, "s1", nil)
	require.NoError(t, err)
	assert.Equal(t, 2, loaded.Version)
	assert.Equal(t, "v2", loaded.PrimaryGoal.Description)
	// Load specific version
	v := 1
	loaded, _, err = repo.Load(ctx, "s1", &v)
	require.NoError(t, err)
	assert.Equal(t, 1, loaded.Version)
	assert.Equal(t, "v1", loaded.PrimaryGoal.Description)
}
// TestStateRepo_Load_NotFound verifies an unknown session produces an error
// and a nil state.
func TestStateRepo_Load_NotFound(t *testing.T) {
	repo := newTestStateRepo(t)
	ctx := context.Background()
	loaded, _, err := repo.Load(ctx, "nonexistent", nil)
	assert.Error(t, err)
	assert.Nil(t, loaded)
}
// TestStateRepo_ListSessions verifies each saved session appears once.
func TestStateRepo_ListSessions(t *testing.T) {
	repo := newTestStateRepo(t)
	ctx := context.Background()
	s1 := session.NewCognitiveStateVector("session-1")
	s2 := session.NewCognitiveStateVector("session-2")
	require.NoError(t, repo.Save(ctx, s1, s1.Checksum()))
	require.NoError(t, repo.Save(ctx, s2, s2.Checksum()))
	sessions, err := repo.ListSessions(ctx)
	require.NoError(t, err)
	assert.Len(t, sessions, 2)
}
// TestStateRepo_DeleteSession verifies deletion removes every stored version
// and that a subsequent Load fails.
func TestStateRepo_DeleteSession(t *testing.T) {
	repo := newTestStateRepo(t)
	ctx := context.Background()
	csv := session.NewCognitiveStateVector("to-delete")
	require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
	csv.BumpVersion()
	require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
	count, err := repo.DeleteSession(ctx, "to-delete")
	require.NoError(t, err)
	assert.Equal(t, 2, count)
	loaded, _, err := repo.Load(ctx, "to-delete", nil)
	assert.Error(t, err)
	assert.Nil(t, loaded)
}
// TestStateRepo_AuditLog verifies each Save produces an audit entry for the
// session.
func TestStateRepo_AuditLog(t *testing.T) {
	repo := newTestStateRepo(t)
	ctx := context.Background()
	csv := session.NewCognitiveStateVector("audited")
	require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
	csv.BumpVersion()
	require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
	log, err := repo.GetAuditLog(ctx, "audited", 10)
	require.NoError(t, err)
	assert.GreaterOrEqual(t, len(log), 2)
	assert.Equal(t, "audited", log[0].SessionID)
}
// TestStateRepo_ComplexState_RoundTrip verifies a fully-populated state
// vector (goal, hypotheses, decisions, facts, questions, confidence map)
// survives JSON serialization and reload intact.
func TestStateRepo_ComplexState_RoundTrip(t *testing.T) {
	repo := newTestStateRepo(t)
	ctx := context.Background()
	csv := session.NewCognitiveStateVector("complex")
	csv.SetGoal("Build full system", 0.7)
	csv.AddHypothesis("H1")
	csv.AddHypothesis("H2")
	csv.AddDecision("D1", "R1", []string{"A1", "A2"})
	csv.AddDecision("D2", "R2", nil)
	csv.AddFact("F1", "requirement", 0.9)
	csv.AddFact("F2", "decision", 1.0)
	csv.AddFact("F3", "context", 0.5)
	csv.OpenQuestions = []string{"Q1", "Q2", "Q3"}
	csv.ConfidenceMap["area1"] = 0.8
	csv.ConfidenceMap["area2"] = 0.3
	require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
	loaded, _, err := repo.Load(ctx, "complex", nil)
	require.NoError(t, err)
	assert.Equal(t, "Build full system", loaded.PrimaryGoal.Description)
	assert.Len(t, loaded.Hypotheses, 2)
	assert.Len(t, loaded.Decisions, 2)
	assert.Len(t, loaded.Facts, 3)
	assert.Len(t, loaded.OpenQuestions, 3)
	assert.InDelta(t, 0.8, loaded.ConfidenceMap["area1"], 0.001)
}

View file

@ -0,0 +1,133 @@
package sqlite
import (
	"context"
	"database/sql"
	"fmt"
	"time"

	"github.com/sentinel-community/gomcp/internal/domain/synapse"
)
// SynapseRepo implements synapse.SynapseStore using SQLite.
type SynapseRepo struct {
	db *DB // shared SQLite handle; the synapses table is created elsewhere
}
// NewSynapseRepo creates a SynapseRepo (table created by FactRepo migration v3.3).
// It performs no migration of its own and assumes the synapses table exists.
func NewSynapseRepo(db *DB) *SynapseRepo {
	return &SynapseRepo{db: db}
}
// Create inserts a new PENDING synapse linking two facts and returns the
// row id assigned by SQLite.
func (r *SynapseRepo) Create(ctx context.Context, factIDA, factIDB string, confidence float64) (int64, error) {
	result, err := r.db.Exec(
		`INSERT INTO synapses (fact_id_a, fact_id_b, confidence, status, created_at)
VALUES (?, ?, ?, 'PENDING', ?)`,
		factIDA, factIDB, confidence, time.Now().Format(timeFormat))
	if err != nil {
		return 0, fmt.Errorf("create synapse: %w", err)
	}
	// Propagate the LastInsertId error instead of silently returning 0.
	id, err := result.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("create synapse: %w", err)
	}
	return id, nil
}
// ListPending returns PENDING synapses ordered by descending confidence.
// A non-positive limit falls back to 50.
func (r *SynapseRepo) ListPending(ctx context.Context, limit int) ([]*synapse.Synapse, error) {
	if limit <= 0 {
		limit = 50
	}
	const q = `SELECT id, fact_id_a, fact_id_b, confidence, status, created_at
FROM synapses WHERE status = 'PENDING' ORDER BY confidence DESC LIMIT ?`
	rows, err := r.db.Query(q, limit)
	if err != nil {
		return nil, fmt.Errorf("list pending: %w", err)
	}
	defer rows.Close()
	return scanSynapses(rows)
}
// Accept transitions a PENDING synapse to VERIFIED.
// It fails when the id is unknown or the synapse is no longer PENDING.
func (r *SynapseRepo) Accept(ctx context.Context, id int64) error {
	res, err := r.db.Exec(`UPDATE synapses SET status = 'VERIFIED' WHERE id = ? AND status = 'PENDING'`, id)
	if err != nil {
		return fmt.Errorf("accept synapse: %w", err)
	}
	if affected, _ := res.RowsAffected(); affected == 0 {
		return fmt.Errorf("synapse %d not found or not PENDING", id)
	}
	return nil
}
// Reject transitions a PENDING synapse to REJECTED.
// It fails when the id is unknown or the synapse is no longer PENDING.
func (r *SynapseRepo) Reject(ctx context.Context, id int64) error {
	res, err := r.db.Exec(`UPDATE synapses SET status = 'REJECTED' WHERE id = ? AND status = 'PENDING'`, id)
	if err != nil {
		return fmt.Errorf("reject synapse: %w", err)
	}
	if affected, _ := res.RowsAffected(); affected == 0 {
		return fmt.Errorf("synapse %d not found or not PENDING", id)
	}
	return nil
}
// ListVerified returns every VERIFIED synapse, highest confidence first.
func (r *SynapseRepo) ListVerified(ctx context.Context) ([]*synapse.Synapse, error) {
	const q = `SELECT id, fact_id_a, fact_id_b, confidence, status, created_at
FROM synapses WHERE status = 'VERIFIED' ORDER BY confidence DESC`
	rows, err := r.db.Query(q)
	if err != nil {
		return nil, fmt.Errorf("list verified: %w", err)
	}
	defer rows.Close()
	return scanSynapses(rows)
}
// Count returns synapse counts by status in a single table scan, using
// conditional SUMs so an empty table yields (0, 0, 0) via COALESCE.
func (r *SynapseRepo) Count(ctx context.Context) (pending, verified, rejected int, err error) {
	row := r.db.QueryRow(`SELECT
COALESCE(SUM(CASE WHEN status='PENDING' THEN 1 ELSE 0 END), 0),
COALESCE(SUM(CASE WHEN status='VERIFIED' THEN 1 ELSE 0 END), 0),
COALESCE(SUM(CASE WHEN status='REJECTED' THEN 1 ELSE 0 END), 0)
FROM synapses`)
	err = row.Scan(&pending, &verified, &rejected)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("count synapses: %w", err)
	}
	return
}
// Exists checks if a synapse exists between two facts (either direction).
func (r *SynapseRepo) Exists(ctx context.Context, factIDA, factIDB string) (bool, error) {
	const q = `SELECT COUNT(*) FROM synapses
WHERE (fact_id_a = ? AND fact_id_b = ?) OR (fact_id_a = ? AND fact_id_b = ?)`
	var n int
	if err := r.db.QueryRow(q, factIDA, factIDB, factIDB, factIDA).Scan(&n); err != nil {
		return false, fmt.Errorf("exists synapse: %w", err)
	}
	return n > 0, nil
}
// Ensure SynapseRepo implements synapse.SynapseStore.
var _ synapse.SynapseStore = (*SynapseRepo)(nil)
// scanSynapses converts result rows into Synapse values.
// It takes *sql.Rows directly (instead of the previous ad-hoc Next/Scan
// interface) so that deferred iteration errors from rows.Err() are
// propagated — matching how scanEvents handles its rows elsewhere in this
// package. The caller retains ownership of rows and must Close them.
func scanSynapses(rows *sql.Rows) ([]*synapse.Synapse, error) {
	var result []*synapse.Synapse
	for rows.Next() {
		var s synapse.Synapse
		var status, createdAt string
		if err := rows.Scan(&s.ID, &s.FactIDA, &s.FactIDB, &s.Confidence, &status, &createdAt); err != nil {
			return nil, fmt.Errorf("scan synapse: %w", err)
		}
		s.Status = synapse.Status(status)
		// A parse failure leaves the zero time.
		s.CreatedAt, _ = time.Parse(timeFormat, createdAt)
		result = append(result, &s)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("scan synapse: %w", err)
	}
	return result, nil
}

View file

@ -0,0 +1,164 @@
package sqlite_test
import (
"context"
"testing"
"github.com/sentinel-community/gomcp/internal/domain/memory"
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupSynapseTest opens an in-memory SQLite database, runs the fact
// repository migrations (which also create the synapses table), and
// returns both repos wired to the same connection. The database is
// closed automatically when the test finishes.
func setupSynapseTest(t *testing.T) (*sqlite.SynapseRepo, *sqlite.FactRepo) {
	t.Helper()
	db, err := sqlite.OpenMemory()
	require.NoError(t, err)
	t.Cleanup(func() { db.Close() })
	// NewFactRepo performs the migration that creates the synapses table.
	facts, err := sqlite.NewFactRepo(db)
	require.NoError(t, err)
	return sqlite.NewSynapseRepo(db), facts
}
// A freshly created synapse must appear in the pending queue with the
// fact IDs, confidence, and PENDING status it was created with.
func TestSynapseRepo_CreateAndListPending(t *testing.T) {
	repo, factRepo := setupSynapseTest(t)
	ctx := context.Background()

	// Both facts must exist before they can be linked.
	factA := memory.NewFact("Architecture overview", memory.LevelDomain, "arch", "")
	factB := memory.NewFact("Security module design", memory.LevelDomain, "security", "")
	require.NoError(t, factRepo.Add(ctx, factA))
	require.NoError(t, factRepo.Add(ctx, factB))

	synID, err := repo.Create(ctx, factA.ID, factB.ID, 0.92)
	require.NoError(t, err)
	assert.Greater(t, synID, int64(0))

	got, err := repo.ListPending(ctx, 10)
	require.NoError(t, err)
	require.Len(t, got, 1)
	assert.Equal(t, factA.ID, got[0].FactIDA)
	assert.Equal(t, factB.ID, got[0].FactIDB)
	assert.InDelta(t, 0.92, got[0].Confidence, 0.01)
	assert.Equal(t, "PENDING", string(got[0].Status))
}
// Accepting a synapse must move it out of the pending queue and into
// the verified list with VERIFIED status.
func TestSynapseRepo_AcceptAndListVerified(t *testing.T) {
	repo, factRepo := setupSynapseTest(t)
	ctx := context.Background()

	a := memory.NewFact("fact A", memory.LevelModule, "test", "")
	b := memory.NewFact("fact B", memory.LevelModule, "test", "")
	require.NoError(t, factRepo.Add(ctx, a))
	require.NoError(t, factRepo.Add(ctx, b))

	id, err := repo.Create(ctx, a.ID, b.ID, 0.88)
	require.NoError(t, err)
	require.NoError(t, repo.Accept(ctx, id))

	// Gone from pending…
	pending, _ := repo.ListPending(ctx, 10)
	assert.Empty(t, pending)

	// …and present in verified.
	verified, err := repo.ListVerified(ctx)
	require.NoError(t, err)
	require.Len(t, verified, 1)
	assert.Equal(t, "VERIFIED", string(verified[0].Status))
}
// A rejected synapse must disappear from both the pending queue and the
// verified list.
func TestSynapseRepo_Reject(t *testing.T) {
	repo, factRepo := setupSynapseTest(t)
	ctx := context.Background()

	x := memory.NewFact("fact X", memory.LevelProject, "", "")
	y := memory.NewFact("fact Y", memory.LevelProject, "", "")
	require.NoError(t, factRepo.Add(ctx, x))
	require.NoError(t, factRepo.Add(ctx, y))

	id, err := repo.Create(ctx, x.ID, y.ID, 0.50)
	require.NoError(t, err)
	require.NoError(t, repo.Reject(ctx, id))

	pending, _ := repo.ListPending(ctx, 10)
	verified, _ := repo.ListVerified(ctx)
	assert.Empty(t, pending)
	assert.Empty(t, verified)
}
// Count must bucket synapses by status: one accepted, one rejected, and
// one left pending should yield 1/1/1.
func TestSynapseRepo_Count(t *testing.T) {
	repo, factRepo := setupSynapseTest(t)
	ctx := context.Background()

	f1 := memory.NewFact("a", memory.LevelProject, "", "")
	f2 := memory.NewFact("b", memory.LevelProject, "", "")
	f3 := memory.NewFact("c", memory.LevelProject, "", "")
	require.NoError(t, factRepo.Add(ctx, f1))
	require.NoError(t, factRepo.Add(ctx, f2))
	require.NoError(t, factRepo.Add(ctx, f3))

	// Check setup errors here so a failure points at the failing call
	// instead of surfacing later as a confusing count mismatch.
	id1, err := repo.Create(ctx, f1.ID, f2.ID, 0.90)
	require.NoError(t, err)
	id2, err := repo.Create(ctx, f1.ID, f3.ID, 0.85)
	require.NoError(t, err)
	_, err = repo.Create(ctx, f2.ID, f3.ID, 0.40)
	require.NoError(t, err)
	require.NoError(t, repo.Accept(ctx, id1))
	require.NoError(t, repo.Reject(ctx, id2))

	pending, verified, rejected, err := repo.Count(ctx)
	require.NoError(t, err)
	assert.Equal(t, 1, pending)
	assert.Equal(t, 1, verified)
	assert.Equal(t, 1, rejected)
}
// Exists must report false before a synapse is created and true in both
// fact orderings afterwards (links are undirected).
func TestSynapseRepo_Exists(t *testing.T) {
	repo, factRepo := setupSynapseTest(t)
	ctx := context.Background()

	f1 := memory.NewFact("p", memory.LevelProject, "", "")
	f2 := memory.NewFact("q", memory.LevelProject, "", "")
	require.NoError(t, factRepo.Add(ctx, f1))
	require.NoError(t, factRepo.Add(ctx, f2))

	// No synapse yet.
	exists, err := repo.Exists(ctx, f1.ID, f2.ID)
	require.NoError(t, err)
	assert.False(t, exists)

	// Create one — check the error so a failed insert doesn't show up
	// as a misleading Exists assertion failure below.
	_, err = repo.Create(ctx, f1.ID, f2.ID, 0.95)
	require.NoError(t, err)

	// Should exist in both directions.
	exists, err = repo.Exists(ctx, f1.ID, f2.ID)
	require.NoError(t, err)
	assert.True(t, exists)
	exists, err = repo.Exists(ctx, f2.ID, f1.ID)
	require.NoError(t, err)
	assert.True(t, exists, "bidirectional check should work")
}
// Accepting a synapse that is no longer PENDING must fail.
func TestSynapseRepo_AcceptNonPending_Fails(t *testing.T) {
	repo, factRepo := setupSynapseTest(t)
	ctx := context.Background()

	f1 := memory.NewFact("m", memory.LevelProject, "", "")
	f2 := memory.NewFact("n", memory.LevelProject, "", "")
	require.NoError(t, factRepo.Add(ctx, f1))
	require.NoError(t, factRepo.Add(ctx, f2))

	// The first Create/Accept are setup and must succeed; checking their
	// errors keeps the final assertion meaningful.
	id, err := repo.Create(ctx, f1.ID, f2.ID, 0.80)
	require.NoError(t, err)
	require.NoError(t, repo.Accept(ctx, id))

	// Accepting an already-VERIFIED synapse should fail.
	err = repo.Accept(ctx, id)
	assert.Error(t, err)
}