mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-27 13:26:21 +02:00
initial: Syntrex extraction from sentinel-community (615 files)
This commit is contained in:
commit
2c50c993b1
175 changed files with 32396 additions and 0 deletions
220
internal/infrastructure/sqlite/causal_repo.go
Normal file
220
internal/infrastructure/sqlite/causal_repo.go
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/causal"
|
||||
)
|
||||
|
||||
// CausalRepo implements causal.CausalStore using SQLite.
|
||||
// Compatible with causal_chains.db schema.
|
||||
type CausalRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewCausalRepo creates a CausalRepo and ensures the schema exists.
|
||||
func NewCausalRepo(db *DB) (*CausalRepo, error) {
|
||||
repo := &CausalRepo{db: db}
|
||||
if err := repo.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("causal repo migrate: %w", err)
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (r *CausalRepo) migrate() error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS causal_nodes (
|
||||
id TEXT PRIMARY KEY,
|
||||
node_type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
session_id TEXT,
|
||||
metadata TEXT DEFAULT '{}'
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS causal_edges (
|
||||
from_id TEXT NOT NULL,
|
||||
to_id TEXT NOT NULL,
|
||||
edge_type TEXT NOT NULL,
|
||||
strength REAL DEFAULT 1.0,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (from_id, to_id, edge_type),
|
||||
FOREIGN KEY (from_id) REFERENCES causal_nodes(id),
|
||||
FOREIGN KEY (to_id) REFERENCES causal_nodes(id)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_nodes_type ON causal_nodes(node_type)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_nodes_session ON causal_nodes(session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_edges_from ON causal_edges(from_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_edges_to ON causal_edges(to_id)`,
|
||||
}
|
||||
for _, s := range stmts {
|
||||
if _, err := r.db.Exec(s); err != nil {
|
||||
return fmt.Errorf("exec migration: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNode inserts a causal node.
|
||||
func (r *CausalRepo) AddNode(ctx context.Context, node *causal.Node) error {
|
||||
if err := node.Validate(); err != nil {
|
||||
return fmt.Errorf("validate node: %w", err)
|
||||
}
|
||||
_, err := r.db.Exec(`INSERT INTO causal_nodes (id, node_type, content, created_at)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
node.ID, string(node.Type), node.Content, node.CreatedAt.Format(timeFormat),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert node: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddEdge inserts a causal edge.
|
||||
func (r *CausalRepo) AddEdge(ctx context.Context, edge *causal.Edge) error {
|
||||
if err := edge.Validate(); err != nil {
|
||||
return fmt.Errorf("validate edge: %w", err)
|
||||
}
|
||||
_, err := r.db.Exec(`INSERT INTO causal_edges (from_id, to_id, edge_type)
|
||||
VALUES (?, ?, ?)`,
|
||||
edge.FromID, edge.ToID, string(edge.Type),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert edge: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChain builds a causal chain around a decision node matching the query.
|
||||
func (r *CausalRepo) GetChain(ctx context.Context, query string, maxDepth int) (*causal.Chain, error) {
|
||||
chain := &causal.Chain{}
|
||||
|
||||
// Find decision node matching query.
|
||||
row := r.db.QueryRow(`SELECT id, node_type, content, created_at
|
||||
FROM causal_nodes WHERE node_type = 'decision' AND content LIKE ? LIMIT 1`,
|
||||
"%"+query+"%")
|
||||
|
||||
var id, nodeType, content, createdAt string
|
||||
err := row.Scan(&id, &nodeType, &content, &createdAt)
|
||||
if err != nil {
|
||||
// No decision found — return empty chain.
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
t, _ := time.Parse(timeFormat, createdAt)
|
||||
chain.Decision = &causal.Node{ID: id, Type: causal.NodeType(nodeType), Content: content, CreatedAt: t}
|
||||
chain.TotalNodes = 1
|
||||
|
||||
// Find all connected nodes via edges.
|
||||
// Incoming edges (nodes that point TO the decision).
|
||||
inRows, err := r.db.Query(`SELECT n.id, n.node_type, n.content, n.created_at, e.edge_type
|
||||
FROM causal_edges e JOIN causal_nodes n ON e.from_id = n.id
|
||||
WHERE e.to_id = ?`, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query incoming edges: %w", err)
|
||||
}
|
||||
defer inRows.Close()
|
||||
|
||||
for inRows.Next() {
|
||||
var nid, nt, nc, nca, et string
|
||||
if err := inRows.Scan(&nid, &nt, &nc, &nca, &et); err != nil {
|
||||
return nil, fmt.Errorf("scan incoming: %w", err)
|
||||
}
|
||||
tt, _ := time.Parse(timeFormat, nca)
|
||||
node := &causal.Node{ID: nid, Type: causal.NodeType(nt), Content: nc, CreatedAt: tt}
|
||||
chain.TotalNodes++
|
||||
|
||||
switch causal.EdgeType(et) {
|
||||
case causal.EdgeJustifies:
|
||||
chain.Reasons = append(chain.Reasons, node)
|
||||
case causal.EdgeConstrains:
|
||||
chain.Constraints = append(chain.Constraints, node)
|
||||
default:
|
||||
// Classify by node type if edge type doesn't match.
|
||||
switch causal.NodeType(nt) {
|
||||
case causal.NodeAlternative:
|
||||
chain.Alternatives = append(chain.Alternatives, node)
|
||||
case causal.NodeReason:
|
||||
chain.Reasons = append(chain.Reasons, node)
|
||||
case causal.NodeConstraint:
|
||||
chain.Constraints = append(chain.Constraints, node)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := inRows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Outgoing edges (nodes that the decision points TO).
|
||||
outRows, err := r.db.Query(`SELECT n.id, n.node_type, n.content, n.created_at, e.edge_type
|
||||
FROM causal_edges e JOIN causal_nodes n ON e.to_id = n.id
|
||||
WHERE e.from_id = ?`, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query outgoing edges: %w", err)
|
||||
}
|
||||
defer outRows.Close()
|
||||
|
||||
for outRows.Next() {
|
||||
var nid, nt, nc, nca, et string
|
||||
if err := outRows.Scan(&nid, &nt, &nc, &nca, &et); err != nil {
|
||||
return nil, fmt.Errorf("scan outgoing: %w", err)
|
||||
}
|
||||
tt, _ := time.Parse(timeFormat, nca)
|
||||
node := &causal.Node{ID: nid, Type: causal.NodeType(nt), Content: nc, CreatedAt: tt}
|
||||
chain.TotalNodes++
|
||||
|
||||
switch causal.EdgeType(et) {
|
||||
case causal.EdgeCauses:
|
||||
chain.Consequences = append(chain.Consequences, node)
|
||||
default:
|
||||
switch causal.NodeType(nt) {
|
||||
case causal.NodeConsequence:
|
||||
chain.Consequences = append(chain.Consequences, node)
|
||||
case causal.NodeAlternative:
|
||||
chain.Alternatives = append(chain.Alternatives, node)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := outRows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// Stats returns aggregate statistics about the causal store.
|
||||
func (r *CausalRepo) Stats(ctx context.Context) (*causal.CausalStats, error) {
|
||||
stats := &causal.CausalStats{
|
||||
ByType: make(map[causal.NodeType]int),
|
||||
}
|
||||
|
||||
row := r.db.QueryRow(`SELECT COUNT(*) FROM causal_nodes`)
|
||||
if err := row.Scan(&stats.TotalNodes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
row = r.db.QueryRow(`SELECT COUNT(*) FROM causal_edges`)
|
||||
if err := row.Scan(&stats.TotalEdges); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := r.db.Query(`SELECT node_type, COUNT(*) FROM causal_nodes GROUP BY node_type`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var nt string
|
||||
var count int
|
||||
if err := rows.Scan(&nt, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.ByType[causal.NodeType(nt)] = count
|
||||
}
|
||||
return stats, rows.Err()
|
||||
}
|
||||
|
||||
// Ensure CausalRepo implements causal.CausalStore.
|
||||
var _ causal.CausalStore = (*CausalRepo)(nil)
|
||||
137
internal/infrastructure/sqlite/causal_repo_test.go
Normal file
137
internal/infrastructure/sqlite/causal_repo_test.go
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/causal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestCausalRepo(t *testing.T) *CausalRepo {
|
||||
t.Helper()
|
||||
db, err := OpenMemory()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { db.Close() })
|
||||
|
||||
repo, err := NewCausalRepo(db)
|
||||
require.NoError(t, err)
|
||||
return repo
|
||||
}
|
||||
|
||||
func TestCausalRepo_AddNode_GetChain(t *testing.T) {
|
||||
repo := newTestCausalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
decision := causal.NewNode(causal.NodeDecision, "Use SQLite")
|
||||
reason := causal.NewNode(causal.NodeReason, "Embedded, no server needed")
|
||||
consequence := causal.NewNode(causal.NodeConsequence, "Single binary deployment")
|
||||
|
||||
require.NoError(t, repo.AddNode(ctx, decision))
|
||||
require.NoError(t, repo.AddNode(ctx, reason))
|
||||
require.NoError(t, repo.AddNode(ctx, consequence))
|
||||
|
||||
e1 := causal.NewEdge(reason.ID, decision.ID, causal.EdgeJustifies)
|
||||
e2 := causal.NewEdge(decision.ID, consequence.ID, causal.EdgeCauses)
|
||||
require.NoError(t, repo.AddEdge(ctx, e1))
|
||||
require.NoError(t, repo.AddEdge(ctx, e2))
|
||||
|
||||
chain, err := repo.GetChain(ctx, "SQLite", 3)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, chain)
|
||||
assert.NotNil(t, chain.Decision)
|
||||
assert.Equal(t, "Use SQLite", chain.Decision.Content)
|
||||
assert.Len(t, chain.Reasons, 1)
|
||||
assert.Len(t, chain.Consequences, 1)
|
||||
}
|
||||
|
||||
func TestCausalRepo_AddNode_Duplicate(t *testing.T) {
|
||||
repo := newTestCausalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
node := causal.NewNode(causal.NodeDecision, "test")
|
||||
require.NoError(t, repo.AddNode(ctx, node))
|
||||
|
||||
err := repo.AddNode(ctx, node)
|
||||
assert.Error(t, err) // duplicate primary key
|
||||
}
|
||||
|
||||
func TestCausalRepo_AddEdge_SelfLoop(t *testing.T) {
|
||||
repo := newTestCausalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
node := causal.NewNode(causal.NodeDecision, "test")
|
||||
require.NoError(t, repo.AddNode(ctx, node))
|
||||
|
||||
edge := causal.NewEdge(node.ID, node.ID, causal.EdgeCauses)
|
||||
err := repo.AddEdge(ctx, edge)
|
||||
assert.Error(t, err) // self-loop validation
|
||||
}
|
||||
|
||||
func TestCausalRepo_GetChain_NoResults(t *testing.T) {
|
||||
repo := newTestCausalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
chain, err := repo.GetChain(ctx, "nonexistent", 3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, chain.TotalNodes)
|
||||
}
|
||||
|
||||
func TestCausalRepo_Stats(t *testing.T) {
|
||||
repo := newTestCausalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
n1 := causal.NewNode(causal.NodeDecision, "D1")
|
||||
n2 := causal.NewNode(causal.NodeReason, "R1")
|
||||
n3 := causal.NewNode(causal.NodeConsequence, "C1")
|
||||
require.NoError(t, repo.AddNode(ctx, n1))
|
||||
require.NoError(t, repo.AddNode(ctx, n2))
|
||||
require.NoError(t, repo.AddNode(ctx, n3))
|
||||
|
||||
e1 := causal.NewEdge(n2.ID, n1.ID, causal.EdgeJustifies)
|
||||
require.NoError(t, repo.AddEdge(ctx, e1))
|
||||
|
||||
stats, err := repo.Stats(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, stats.TotalNodes)
|
||||
assert.Equal(t, 1, stats.TotalEdges)
|
||||
assert.Equal(t, 1, stats.ByType[causal.NodeDecision])
|
||||
assert.Equal(t, 1, stats.ByType[causal.NodeReason])
|
||||
assert.Equal(t, 1, stats.ByType[causal.NodeConsequence])
|
||||
}
|
||||
|
||||
func TestCausalRepo_ComplexChain(t *testing.T) {
|
||||
repo := newTestCausalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
decision := causal.NewNode(causal.NodeDecision, "Use Go for MCP server")
|
||||
r1 := causal.NewNode(causal.NodeReason, "Performance")
|
||||
r2 := causal.NewNode(causal.NodeReason, "Single binary")
|
||||
c1 := causal.NewNode(causal.NodeConsequence, "Faster startup")
|
||||
cn1 := causal.NewNode(causal.NodeConstraint, "Must support CGO-free")
|
||||
a1 := causal.NewNode(causal.NodeAlternative, "Use Rust")
|
||||
|
||||
for _, n := range []*causal.Node{decision, r1, r2, c1, cn1, a1} {
|
||||
require.NoError(t, repo.AddNode(ctx, n))
|
||||
}
|
||||
|
||||
edges := []*causal.Edge{
|
||||
causal.NewEdge(r1.ID, decision.ID, causal.EdgeJustifies),
|
||||
causal.NewEdge(r2.ID, decision.ID, causal.EdgeJustifies),
|
||||
causal.NewEdge(decision.ID, c1.ID, causal.EdgeCauses),
|
||||
causal.NewEdge(cn1.ID, decision.ID, causal.EdgeConstrains),
|
||||
}
|
||||
for _, e := range edges {
|
||||
require.NoError(t, repo.AddEdge(ctx, e))
|
||||
}
|
||||
|
||||
chain, err := repo.GetChain(ctx, "Go for MCP", 5)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, chain.Decision)
|
||||
assert.Equal(t, "Use Go for MCP server", chain.Decision.Content)
|
||||
assert.Len(t, chain.Reasons, 2)
|
||||
assert.Len(t, chain.Consequences, 1)
|
||||
assert.Len(t, chain.Constraints, 1)
|
||||
assert.GreaterOrEqual(t, chain.TotalNodes, 5)
|
||||
}
|
||||
254
internal/infrastructure/sqlite/crystal_repo.go
Normal file
254
internal/infrastructure/sqlite/crystal_repo.go
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/crystal"
|
||||
)
|
||||
|
||||
// CrystalRepo implements crystal.CrystalStore using SQLite.
|
||||
// Compatible with crystals.db schema.
|
||||
type CrystalRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewCrystalRepo creates a CrystalRepo and ensures the schema exists.
|
||||
func NewCrystalRepo(db *DB) (*CrystalRepo, error) {
|
||||
repo := &CrystalRepo{db: db}
|
||||
if err := repo.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("crystal repo migrate: %w", err)
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (r *CrystalRepo) migrate() error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS crystals (
|
||||
path TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
content BLOB,
|
||||
primitives_count INTEGER,
|
||||
token_count INTEGER,
|
||||
indexed_at REAL,
|
||||
source_mtime REAL,
|
||||
source_hash TEXT,
|
||||
last_validated REAL,
|
||||
human_confirmed INTEGER DEFAULT 0
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS metadata (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_mtime ON crystals(source_mtime)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_indexed ON crystals(indexed_at)`,
|
||||
}
|
||||
for _, s := range stmts {
|
||||
if _, err := r.db.Exec(s); err != nil {
|
||||
return fmt.Errorf("exec migration: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// serializedCrystal is the JSON structure stored in the content BLOB.
|
||||
type serializedCrystal struct {
|
||||
Path string `json:"path"`
|
||||
Name string `json:"name"`
|
||||
TokenCount int `json:"token_count"`
|
||||
ContentHash string `json:"content_hash"`
|
||||
Primitives []crystal.Primitive `json:"primitives"`
|
||||
}
|
||||
|
||||
// Upsert inserts or replaces a crystal.
|
||||
func (r *CrystalRepo) Upsert(ctx context.Context, c *crystal.Crystal) error {
|
||||
sc := serializedCrystal{
|
||||
Path: c.Path,
|
||||
Name: c.Name,
|
||||
TokenCount: c.TokenCount,
|
||||
ContentHash: c.ContentHash,
|
||||
Primitives: c.Primitives,
|
||||
}
|
||||
blob, err := json.Marshal(sc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal crystal: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.db.Exec(`INSERT OR REPLACE INTO crystals
|
||||
(path, name, content, primitives_count, token_count,
|
||||
indexed_at, source_mtime, source_hash, last_validated, human_confirmed)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
c.Path, c.Name, blob, c.PrimitivesCount, c.TokenCount,
|
||||
c.IndexedAt, c.SourceMtime, c.SourceHash,
|
||||
c.LastValidated, boolToInt(c.HumanConfirmed),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upsert crystal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a crystal by path.
|
||||
func (r *CrystalRepo) Get(ctx context.Context, path string) (*crystal.Crystal, error) {
|
||||
row := r.db.QueryRow(`SELECT path, name, content, primitives_count, token_count,
|
||||
indexed_at, source_mtime, source_hash, last_validated, human_confirmed
|
||||
FROM crystals WHERE path = ?`, path)
|
||||
|
||||
var c crystal.Crystal
|
||||
var blob []byte
|
||||
var lastValidated sql.NullFloat64
|
||||
var humanConfirmed int
|
||||
|
||||
err := row.Scan(&c.Path, &c.Name, &blob, &c.PrimitivesCount, &c.TokenCount,
|
||||
&c.IndexedAt, &c.SourceMtime, &c.SourceHash, &lastValidated, &humanConfirmed)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("crystal %s not found", path)
|
||||
}
|
||||
return nil, fmt.Errorf("scan crystal: %w", err)
|
||||
}
|
||||
|
||||
if lastValidated.Valid {
|
||||
c.LastValidated = lastValidated.Float64
|
||||
}
|
||||
c.HumanConfirmed = humanConfirmed != 0
|
||||
|
||||
if len(blob) > 0 {
|
||||
var sc serializedCrystal
|
||||
if err := json.Unmarshal(blob, &sc); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal crystal content: %w", err)
|
||||
}
|
||||
c.Primitives = sc.Primitives
|
||||
c.ContentHash = sc.ContentHash
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// Delete removes a crystal by path.
|
||||
func (r *CrystalRepo) Delete(ctx context.Context, path string) error {
|
||||
result, err := r.db.Exec(`DELETE FROM crystals WHERE path = ?`, path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete crystal: %w", err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("crystal %s not found", path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns crystals matching a path pattern. Empty pattern returns all.
|
||||
func (r *CrystalRepo) List(ctx context.Context, pattern string, limit int) ([]*crystal.Crystal, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if pattern == "" {
|
||||
rows, err = r.db.Query(`SELECT path, name, content, primitives_count, token_count,
|
||||
indexed_at, source_mtime, source_hash, last_validated, human_confirmed
|
||||
FROM crystals ORDER BY indexed_at DESC LIMIT ?`, limit)
|
||||
} else {
|
||||
rows, err = r.db.Query(`SELECT path, name, content, primitives_count, token_count,
|
||||
indexed_at, source_mtime, source_hash, last_validated, human_confirmed
|
||||
FROM crystals WHERE path LIKE ? ORDER BY indexed_at DESC LIMIT ?`, pattern, limit)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list crystals: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanCrystals(rows)
|
||||
}
|
||||
|
||||
// Search searches crystal primitives by name/value containing query.
|
||||
func (r *CrystalRepo) Search(ctx context.Context, query string, limit int) ([]*crystal.Crystal, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
// Search in content BLOB (JSON text stored as BLOB) for primitive names/values.
|
||||
// CAST to TEXT required because SQLite LIKE doesn't match on raw BLOB.
|
||||
rows, err := r.db.Query(`SELECT path, name, content, primitives_count, token_count,
|
||||
indexed_at, source_mtime, source_hash, last_validated, human_confirmed
|
||||
FROM crystals WHERE CAST(content AS TEXT) LIKE ? OR name LIKE ? LIMIT ?`,
|
||||
"%"+query+"%", "%"+query+"%", limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search crystals: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanCrystals(rows)
|
||||
}
|
||||
|
||||
// Stats returns aggregate statistics.
|
||||
func (r *CrystalRepo) Stats(ctx context.Context) (*crystal.CrystalStats, error) {
|
||||
stats := &crystal.CrystalStats{
|
||||
ByExtension: make(map[string]int),
|
||||
}
|
||||
|
||||
row := r.db.QueryRow(`SELECT COUNT(*), COALESCE(SUM(primitives_count),0), COALESCE(SUM(token_count),0) FROM crystals`)
|
||||
if err := row.Scan(&stats.TotalCrystals, &stats.TotalPrimitives, &stats.TotalTokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Count by file extension.
|
||||
rows, err := r.db.Query(`SELECT path FROM crystals`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var p string
|
||||
if err := rows.Scan(&p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ext := strings.ToLower(filepath.Ext(p))
|
||||
if ext == "" {
|
||||
ext = "(no ext)"
|
||||
}
|
||||
stats.ByExtension[ext]++
|
||||
}
|
||||
return stats, rows.Err()
|
||||
}
|
||||
|
||||
func scanCrystals(rows *sql.Rows) ([]*crystal.Crystal, error) {
|
||||
var result []*crystal.Crystal
|
||||
for rows.Next() {
|
||||
var c crystal.Crystal
|
||||
var blob []byte
|
||||
var lastValidated sql.NullFloat64
|
||||
var humanConfirmed int
|
||||
|
||||
err := rows.Scan(&c.Path, &c.Name, &blob, &c.PrimitivesCount, &c.TokenCount,
|
||||
&c.IndexedAt, &c.SourceMtime, &c.SourceHash, &lastValidated, &humanConfirmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan crystal row: %w", err)
|
||||
}
|
||||
|
||||
if lastValidated.Valid {
|
||||
c.LastValidated = lastValidated.Float64
|
||||
}
|
||||
c.HumanConfirmed = humanConfirmed != 0
|
||||
|
||||
if len(blob) > 0 {
|
||||
var sc serializedCrystal
|
||||
if err := json.Unmarshal(blob, &sc); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal crystal content: %w", err)
|
||||
}
|
||||
c.Primitives = sc.Primitives
|
||||
c.ContentHash = sc.ContentHash
|
||||
}
|
||||
|
||||
result = append(result, &c)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
// Ensure CrystalRepo implements crystal.CrystalStore.
|
||||
var _ crystal.CrystalStore = (*CrystalRepo)(nil)
|
||||
160
internal/infrastructure/sqlite/crystal_repo_test.go
Normal file
160
internal/infrastructure/sqlite/crystal_repo_test.go
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/crystal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestCrystalRepo(t *testing.T) *CrystalRepo {
|
||||
t.Helper()
|
||||
db, err := OpenMemory()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { db.Close() })
|
||||
|
||||
repo, err := NewCrystalRepo(db)
|
||||
require.NoError(t, err)
|
||||
return repo
|
||||
}
|
||||
|
||||
func makeCrystal(path, name string) *crystal.Crystal {
|
||||
return &crystal.Crystal{
|
||||
Path: path,
|
||||
Name: name,
|
||||
TokenCount: 150,
|
||||
ContentHash: "hash123",
|
||||
PrimitivesCount: 2,
|
||||
Primitives: []crystal.Primitive{
|
||||
{PType: "function", Name: "main", Value: "func main()", SourceLine: 1, Confidence: 1.0},
|
||||
{PType: "function", Name: "init", Value: "func init()", SourceLine: 5, Confidence: 0.9},
|
||||
},
|
||||
IndexedAt: 1700000000.0,
|
||||
SourceMtime: 1699999000.0,
|
||||
SourceHash: "src_hash",
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrystalRepo_Upsert_Get(t *testing.T) {
|
||||
repo := newTestCrystalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
c := makeCrystal("cmd/main.go", "main.go")
|
||||
require.NoError(t, repo.Upsert(ctx, c))
|
||||
|
||||
got, err := repo.Get(ctx, "cmd/main.go")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
assert.Equal(t, "cmd/main.go", got.Path)
|
||||
assert.Equal(t, "main.go", got.Name)
|
||||
assert.Equal(t, 150, got.TokenCount)
|
||||
assert.Equal(t, 2, got.PrimitivesCount)
|
||||
assert.Len(t, got.Primitives, 2)
|
||||
assert.Equal(t, "function", got.Primitives[0].PType)
|
||||
}
|
||||
|
||||
func TestCrystalRepo_Upsert_Overwrite(t *testing.T) {
|
||||
repo := newTestCrystalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
c := makeCrystal("main.go", "main.go")
|
||||
require.NoError(t, repo.Upsert(ctx, c))
|
||||
|
||||
c.TokenCount = 300
|
||||
c.PrimitivesCount = 5
|
||||
require.NoError(t, repo.Upsert(ctx, c))
|
||||
|
||||
got, err := repo.Get(ctx, "main.go")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 300, got.TokenCount)
|
||||
assert.Equal(t, 5, got.PrimitivesCount)
|
||||
}
|
||||
|
||||
func TestCrystalRepo_Get_NotFound(t *testing.T) {
|
||||
repo := newTestCrystalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
got, err := repo.Get(ctx, "nonexistent.go")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestCrystalRepo_Delete(t *testing.T) {
|
||||
repo := newTestCrystalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
c := makeCrystal("delete_me.go", "delete_me.go")
|
||||
require.NoError(t, repo.Upsert(ctx, c))
|
||||
require.NoError(t, repo.Delete(ctx, "delete_me.go"))
|
||||
|
||||
got, err := repo.Get(ctx, "delete_me.go")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestCrystalRepo_List(t *testing.T) {
|
||||
repo := newTestCrystalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, p := range []string{"cmd/main.go", "internal/foo.go", "internal/bar.go", "README.md"} {
|
||||
require.NoError(t, repo.Upsert(ctx, makeCrystal(p, p)))
|
||||
}
|
||||
|
||||
// List all
|
||||
all, err := repo.List(ctx, "", 100)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, all, 4)
|
||||
|
||||
// List with pattern
|
||||
internal, err := repo.List(ctx, "internal%", 100)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, internal, 2)
|
||||
}
|
||||
|
||||
func TestCrystalRepo_Search(t *testing.T) {
|
||||
repo := newTestCrystalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
c1 := makeCrystal("server.go", "server.go")
|
||||
c1.Primitives = []crystal.Primitive{
|
||||
{PType: "function", Name: "handleRequest", Value: "func handleRequest()", SourceLine: 10, Confidence: 1.0},
|
||||
}
|
||||
c2 := makeCrystal("client.go", "client.go")
|
||||
c2.Primitives = []crystal.Primitive{
|
||||
{PType: "function", Name: "sendRequest", Value: "func sendRequest()", SourceLine: 5, Confidence: 1.0},
|
||||
}
|
||||
c3 := makeCrystal("utils.go", "utils.go")
|
||||
c3.Primitives = []crystal.Primitive{
|
||||
{PType: "function", Name: "helper", Value: "func helper()", SourceLine: 1, Confidence: 1.0},
|
||||
}
|
||||
|
||||
for _, c := range []*crystal.Crystal{c1, c2, c3} {
|
||||
require.NoError(t, repo.Upsert(ctx, c))
|
||||
}
|
||||
|
||||
results, err := repo.Search(ctx, "Request", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
}
|
||||
|
||||
func TestCrystalRepo_Stats(t *testing.T) {
|
||||
repo := newTestCrystalRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
c1 := makeCrystal("main.go", "main.go")
|
||||
c1.TokenCount = 100
|
||||
c2 := makeCrystal("server.py", "server.py")
|
||||
c2.TokenCount = 200
|
||||
c2.PrimitivesCount = 5
|
||||
|
||||
require.NoError(t, repo.Upsert(ctx, c1))
|
||||
require.NoError(t, repo.Upsert(ctx, c2))
|
||||
|
||||
stats, err := repo.Stats(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, stats.TotalCrystals)
|
||||
assert.Equal(t, 300, stats.TotalTokens)
|
||||
}
|
||||
105
internal/infrastructure/sqlite/db.go
Normal file
105
internal/infrastructure/sqlite/db.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
// Package sqlite provides SQLite-based persistence using modernc.org/sqlite (pure Go, no CGO).
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// DB wraps a *sql.DB with SQLite-specific configuration.
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
path string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Open opens or creates an SQLite database at the given path.
|
||||
// It applies WAL mode and recommended pragmas for performance.
|
||||
func Open(path string) (*DB, error) {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create db directory: %w", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open sqlite: %w", err)
|
||||
}
|
||||
|
||||
// Apply performance pragmas.
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA synchronous=NORMAL",
|
||||
"PRAGMA cache_size=-64000", // 64MB
|
||||
"PRAGMA foreign_keys=ON",
|
||||
"PRAGMA busy_timeout=5000",
|
||||
}
|
||||
for _, p := range pragmas {
|
||||
if _, err := db.Exec(p); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("exec pragma %q: %w", p, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Connection pool settings for SQLite (single writer).
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
|
||||
return &DB{db: db, path: path}, nil
|
||||
}
|
||||
|
||||
// OpenMemory opens an in-memory SQLite database (for testing).
|
||||
func OpenMemory() (*DB, error) {
|
||||
db, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open in-memory sqlite: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("enable foreign keys: %w", err)
|
||||
}
|
||||
|
||||
return &DB{db: db, path: ":memory:"}, nil
|
||||
}
|
||||
|
||||
// SqlDB returns the underlying *sql.DB.
|
||||
func (d *DB) SqlDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
// Path returns the database file path.
|
||||
func (d *DB) Path() string {
|
||||
return d.path
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
// Exec executes a query that doesn't return rows.
|
||||
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.db.Exec(query, args...)
|
||||
}
|
||||
|
||||
// Query executes a query that returns rows.
|
||||
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.db.Query(query, args...)
|
||||
}
|
||||
|
||||
// QueryRow executes a query that returns at most one row.
|
||||
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.db.QueryRow(query, args...)
|
||||
}
|
||||
698
internal/infrastructure/sqlite/fact_repo.go
Normal file
698
internal/infrastructure/sqlite/fact_repo.go
Normal file
|
|
@ -0,0 +1,698 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
)
|
||||
|
||||
const timeFormat = time.RFC3339Nano
|
||||
|
||||
// FactRepo implements memory.FactStore using SQLite.
|
||||
// Compatible with memory_bridge_v2.db schema v2.0.0.
|
||||
type FactRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewFactRepo creates a FactRepo and ensures the schema exists.
|
||||
func NewFactRepo(db *DB) (*FactRepo, error) {
|
||||
repo := &FactRepo{db: db}
|
||||
if err := repo.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("fact repo migrate: %w", err)
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (r *FactRepo) migrate() error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS hierarchical_facts (
|
||||
id TEXT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
level INTEGER NOT NULL DEFAULT 0,
|
||||
domain TEXT,
|
||||
module TEXT,
|
||||
code_ref TEXT,
|
||||
parent_id TEXT,
|
||||
embedding BLOB,
|
||||
ttl_config TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
valid_from TEXT NOT NULL,
|
||||
valid_until TEXT,
|
||||
is_stale INTEGER DEFAULT 0,
|
||||
is_archived INTEGER DEFAULT 0,
|
||||
confidence REAL DEFAULT 1.0,
|
||||
source TEXT DEFAULT 'manual',
|
||||
session_id TEXT,
|
||||
FOREIGN KEY (parent_id) REFERENCES hierarchical_facts(id)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS fact_hierarchy (
|
||||
parent_id TEXT NOT NULL,
|
||||
child_id TEXT NOT NULL,
|
||||
relationship TEXT DEFAULT 'contains',
|
||||
PRIMARY KEY (parent_id, child_id),
|
||||
FOREIGN KEY (parent_id) REFERENCES hierarchical_facts(id),
|
||||
FOREIGN KEY (child_id) REFERENCES hierarchical_facts(id)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS embeddings_index (
|
||||
fact_id TEXT PRIMARY KEY,
|
||||
embedding BLOB NOT NULL,
|
||||
model_name TEXT DEFAULT 'all-MiniLM-L6-v2',
|
||||
updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (fact_id) REFERENCES hierarchical_facts(id)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS domain_centroids (
|
||||
domain TEXT PRIMARY KEY,
|
||||
centroid BLOB NOT NULL,
|
||||
fact_count INTEGER DEFAULT 0,
|
||||
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS schema_info (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_facts_level ON hierarchical_facts(level)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_facts_domain ON hierarchical_facts(domain)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_facts_module ON hierarchical_facts(module)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_facts_stale ON hierarchical_facts(is_stale)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_facts_session ON hierarchical_facts(session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_facts_source ON hierarchical_facts(source)`,
|
||||
`INSERT OR REPLACE INTO schema_info (key, value) VALUES ('version', '3.3.0')`,
|
||||
}
|
||||
|
||||
// v3.3 migration: add hit_count, last_accessed_at, synapses table.
|
||||
v33Stmts := []string{
|
||||
// Safe ALTER TABLE — ignore error if column already exists.
|
||||
`ALTER TABLE hierarchical_facts ADD COLUMN hit_count INTEGER DEFAULT 0`,
|
||||
`ALTER TABLE hierarchical_facts ADD COLUMN last_accessed_at TEXT`,
|
||||
`CREATE TABLE IF NOT EXISTS synapses (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
fact_id_a TEXT NOT NULL,
|
||||
fact_id_b TEXT NOT NULL,
|
||||
confidence REAL DEFAULT 0.0,
|
||||
status TEXT DEFAULT 'PENDING',
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (fact_id_a) REFERENCES hierarchical_facts(id),
|
||||
FOREIGN KEY (fact_id_b) REFERENCES hierarchical_facts(id)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_facts_hit_count ON hierarchical_facts(hit_count)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_synapses_status ON synapses(status)`,
|
||||
}
|
||||
for _, s := range stmts {
|
||||
if _, err := r.db.Exec(s); err != nil {
|
||||
return fmt.Errorf("exec %q: %w", s[:40], err)
|
||||
}
|
||||
}
|
||||
// v3.3: ignore errors on ALTER TABLE (column may already exist).
|
||||
for _, s := range v33Stmts {
|
||||
_, _ = r.db.Exec(s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add inserts a new fact.
|
||||
func (r *FactRepo) Add(ctx context.Context, fact *memory.Fact) error {
|
||||
embeddingBlob, err := encodeEmbedding(fact.Embedding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttlJSON, err := encodeTTL(fact.TTL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var validUntil *string
|
||||
if fact.ValidUntil != nil {
|
||||
s := fact.ValidUntil.Format(timeFormat)
|
||||
validUntil = &s
|
||||
}
|
||||
|
||||
_, err = r.db.Exec(`INSERT INTO hierarchical_facts
|
||||
(id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
fact.ID, fact.Content, int(fact.Level),
|
||||
nullStr(fact.Domain), nullStr(fact.Module), nullStr(fact.CodeRef),
|
||||
nullStr(fact.ParentID),
|
||||
embeddingBlob, ttlJSON,
|
||||
fact.CreatedAt.Format(timeFormat), fact.ValidFrom.Format(timeFormat), validUntil,
|
||||
boolToInt(fact.IsStale), boolToInt(fact.IsArchived),
|
||||
fact.Confidence, fact.Source, nullStr(fact.SessionID),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert fact: %w", err)
|
||||
}
|
||||
|
||||
// Also insert into embeddings_index if embedding exists.
|
||||
if len(fact.Embedding) > 0 {
|
||||
_, err = r.db.Exec(`INSERT OR REPLACE INTO embeddings_index (fact_id, embedding) VALUES (?, ?)`,
|
||||
fact.ID, embeddingBlob)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert embedding index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a fact by ID.
|
||||
func (r *FactRepo) Get(ctx context.Context, id string) (*memory.Fact, error) {
|
||||
row := r.db.QueryRow(`SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE id = ?`, id)
|
||||
return scanFact(row)
|
||||
}
|
||||
|
||||
// Update updates an existing fact.
|
||||
func (r *FactRepo) Update(ctx context.Context, fact *memory.Fact) error {
|
||||
embeddingBlob, err := encodeEmbedding(fact.Embedding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttlJSON, err := encodeTTL(fact.TTL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var validUntil *string
|
||||
if fact.ValidUntil != nil {
|
||||
s := fact.ValidUntil.Format(timeFormat)
|
||||
validUntil = &s
|
||||
}
|
||||
|
||||
result, err := r.db.Exec(`UPDATE hierarchical_facts SET
|
||||
content=?, level=?, domain=?, module=?, code_ref=?, parent_id=?,
|
||||
embedding=?, ttl_config=?, created_at=?, valid_from=?, valid_until=?,
|
||||
is_stale=?, is_archived=?, confidence=?, source=?, session_id=?
|
||||
WHERE id=?`,
|
||||
fact.Content, int(fact.Level),
|
||||
nullStr(fact.Domain), nullStr(fact.Module), nullStr(fact.CodeRef),
|
||||
nullStr(fact.ParentID),
|
||||
embeddingBlob, ttlJSON,
|
||||
fact.CreatedAt.Format(timeFormat), fact.ValidFrom.Format(timeFormat), validUntil,
|
||||
boolToInt(fact.IsStale), boolToInt(fact.IsArchived),
|
||||
fact.Confidence, fact.Source, nullStr(fact.SessionID),
|
||||
fact.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update fact: %w", err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("fact %s not found", fact.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a fact by ID.
|
||||
func (r *FactRepo) Delete(ctx context.Context, id string) error {
|
||||
// Remove from embeddings_index first (FK).
|
||||
_, _ = r.db.Exec(`DELETE FROM embeddings_index WHERE fact_id = ?`, id)
|
||||
// Remove from fact_hierarchy.
|
||||
_, _ = r.db.Exec(`DELETE FROM fact_hierarchy WHERE parent_id = ? OR child_id = ?`, id, id)
|
||||
// Remove the fact.
|
||||
result, err := r.db.Exec(`DELETE FROM hierarchical_facts WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete fact: %w", err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("fact %s not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByDomain returns facts in a domain, optionally including stale ones.
|
||||
func (r *FactRepo) ListByDomain(ctx context.Context, domain string, includeStale bool) ([]*memory.Fact, error) {
|
||||
var query string
|
||||
if includeStale {
|
||||
query = `SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE domain = ?`
|
||||
} else {
|
||||
query = `SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE domain = ? AND is_stale = 0`
|
||||
}
|
||||
rows, err := r.db.Query(query, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list by domain: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanFacts(rows)
|
||||
}
|
||||
|
||||
// ListByLevel returns all facts at a given hierarchy level.
|
||||
func (r *FactRepo) ListByLevel(ctx context.Context, level memory.HierLevel) ([]*memory.Fact, error) {
|
||||
rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE level = ?`, int(level))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list by level: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanFacts(rows)
|
||||
}
|
||||
|
||||
// ListDomains returns distinct domain names.
|
||||
func (r *FactRepo) ListDomains(ctx context.Context) ([]string, error) {
|
||||
rows, err := r.db.Query(`SELECT DISTINCT domain FROM hierarchical_facts WHERE domain IS NOT NULL AND domain != ''`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list domains: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var domains []string
|
||||
for rows.Next() {
|
||||
var d string
|
||||
if err := rows.Scan(&d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domains = append(domains, d)
|
||||
}
|
||||
return domains, rows.Err()
|
||||
}
|
||||
|
||||
// GetStale returns stale facts, optionally including archived ones.
|
||||
func (r *FactRepo) GetStale(ctx context.Context, includeArchived bool) ([]*memory.Fact, error) {
|
||||
var query string
|
||||
if includeArchived {
|
||||
query = `SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE is_stale = 1`
|
||||
} else {
|
||||
query = `SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE is_stale = 1 AND is_archived = 0`
|
||||
}
|
||||
rows, err := r.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get stale: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanFacts(rows)
|
||||
}
|
||||
|
||||
// Search performs a LIKE-based text search on fact content.
|
||||
func (r *FactRepo) Search(ctx context.Context, query string, limit int) ([]*memory.Fact, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE content LIKE ? LIMIT ?`,
|
||||
"%"+query+"%", limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanFacts(rows)
|
||||
}
|
||||
|
||||
// GetExpired returns facts whose TTL has expired.
|
||||
func (r *FactRepo) GetExpired(ctx context.Context) ([]*memory.Fact, error) {
|
||||
// Get all facts with TTL config, check expiry in Go.
|
||||
rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE ttl_config IS NOT NULL AND ttl_config != ''`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get expired: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
all, err := scanFacts(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var expired []*memory.Fact
|
||||
for _, f := range all {
|
||||
if f.TTL != nil && f.TTL.IsExpired(f.CreatedAt) {
|
||||
expired = append(expired, f)
|
||||
}
|
||||
}
|
||||
return expired, nil
|
||||
}
|
||||
|
||||
// RefreshTTL resets the created_at timestamp for a fact (effectively refreshing its TTL).
|
||||
func (r *FactRepo) RefreshTTL(ctx context.Context, id string) error {
|
||||
now := time.Now().Format(timeFormat)
|
||||
result, err := r.db.Exec(`UPDATE hierarchical_facts SET created_at = ? WHERE id = ?`, now, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("refresh ttl: %w", err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("fact %s not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns aggregate statistics about the fact store.
|
||||
func (r *FactRepo) Stats(ctx context.Context) (*memory.FactStoreStats, error) {
|
||||
stats := &memory.FactStoreStats{
|
||||
ByLevel: make(map[memory.HierLevel]int),
|
||||
ByDomain: make(map[string]int),
|
||||
}
|
||||
|
||||
// Total facts.
|
||||
row := r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts`)
|
||||
if err := row.Scan(&stats.TotalFacts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Stale count.
|
||||
row = r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts WHERE is_stale = 1`)
|
||||
if err := row.Scan(&stats.StaleCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// With embeddings.
|
||||
row = r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts WHERE embedding IS NOT NULL`)
|
||||
if err := row.Scan(&stats.WithEmbeddings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// By level.
|
||||
rows, err := r.db.Query(`SELECT level, COUNT(*) FROM hierarchical_facts GROUP BY level`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var level, count int
|
||||
if err := rows.Scan(&level, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.ByLevel[memory.HierLevel(level)] = count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// By domain.
|
||||
rows2, err := r.db.Query(`SELECT domain, COUNT(*) FROM hierarchical_facts WHERE domain IS NOT NULL AND domain != '' GROUP BY domain`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows2.Close()
|
||||
for rows2.Next() {
|
||||
var domain string
|
||||
var count int
|
||||
if err := rows2.Scan(&domain, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.ByDomain[domain] = count
|
||||
}
|
||||
|
||||
// Gene count (Genome Layer).
|
||||
row = r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts WHERE source = 'genome'`)
|
||||
if err := row.Scan(&stats.GeneCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// v3.3: Cold count (hit_count=0, created >30 days ago, not gene, not archived).
|
||||
thirtyDaysAgo := time.Now().AddDate(0, 0, -30).Format(timeFormat)
|
||||
row = r.db.QueryRow(`SELECT COUNT(*) FROM hierarchical_facts
|
||||
WHERE hit_count = 0 AND created_at < ? AND source != 'genome' AND is_archived = 0`,
|
||||
thirtyDaysAgo)
|
||||
if err := row.Scan(&stats.ColdCount); err != nil {
|
||||
// Ignore if column doesn't exist yet.
|
||||
stats.ColdCount = 0
|
||||
}
|
||||
|
||||
return stats, rows2.Err()
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func scanFact(row *sql.Row) (*memory.Fact, error) {
|
||||
var f memory.Fact
|
||||
var levelInt int
|
||||
var domain, module, codeRef, parentID, sessionID sql.NullString
|
||||
var embeddingBlob []byte
|
||||
var ttlJSON sql.NullString
|
||||
var createdAt, validFrom string
|
||||
var validUntil sql.NullString
|
||||
var isStale, isArchived int
|
||||
|
||||
err := row.Scan(&f.ID, &f.Content, &levelInt,
|
||||
&domain, &module, &codeRef, &parentID,
|
||||
&embeddingBlob, &ttlJSON,
|
||||
&createdAt, &validFrom, &validUntil,
|
||||
&isStale, &isArchived, &f.Confidence, &f.Source, &sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("fact not found")
|
||||
}
|
||||
return nil, fmt.Errorf("scan fact: %w", err)
|
||||
}
|
||||
|
||||
f.Level = memory.HierLevel(levelInt)
|
||||
f.Domain = domain.String
|
||||
f.Module = module.String
|
||||
f.CodeRef = codeRef.String
|
||||
f.ParentID = parentID.String
|
||||
f.SessionID = sessionID.String
|
||||
f.IsStale = isStale != 0
|
||||
f.IsArchived = isArchived != 0
|
||||
f.IsGene = f.Source == "genome" // Genome Layer: auto-detect from source
|
||||
|
||||
f.CreatedAt, _ = time.Parse(timeFormat, createdAt)
|
||||
f.ValidFrom, _ = time.Parse(timeFormat, validFrom)
|
||||
f.UpdatedAt = f.CreatedAt // We don't have a separate updated_at column in the DB schema.
|
||||
|
||||
if validUntil.Valid {
|
||||
t, _ := time.Parse(timeFormat, validUntil.String)
|
||||
f.ValidUntil = &t
|
||||
}
|
||||
|
||||
if len(embeddingBlob) > 0 {
|
||||
if err := json.Unmarshal(embeddingBlob, &f.Embedding); err != nil {
|
||||
return nil, fmt.Errorf("decode embedding: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if ttlJSON.Valid && ttlJSON.String != "" {
|
||||
f.TTL = &memory.TTLConfig{}
|
||||
if err := json.Unmarshal([]byte(ttlJSON.String), f.TTL); err != nil {
|
||||
return nil, fmt.Errorf("decode ttl_config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
func scanFacts(rows *sql.Rows) ([]*memory.Fact, error) {
|
||||
var facts []*memory.Fact
|
||||
for rows.Next() {
|
||||
var f memory.Fact
|
||||
var levelInt int
|
||||
var domain, module, codeRef, parentID, sessionID sql.NullString
|
||||
var embeddingBlob []byte
|
||||
var ttlJSON sql.NullString
|
||||
var createdAt, validFrom string
|
||||
var validUntil sql.NullString
|
||||
var isStale, isArchived int
|
||||
|
||||
err := rows.Scan(&f.ID, &f.Content, &levelInt,
|
||||
&domain, &module, &codeRef, &parentID,
|
||||
&embeddingBlob, &ttlJSON,
|
||||
&createdAt, &validFrom, &validUntil,
|
||||
&isStale, &isArchived, &f.Confidence, &f.Source, &sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan fact row: %w", err)
|
||||
}
|
||||
|
||||
f.Level = memory.HierLevel(levelInt)
|
||||
f.Domain = domain.String
|
||||
f.Module = module.String
|
||||
f.CodeRef = codeRef.String
|
||||
f.ParentID = parentID.String
|
||||
f.SessionID = sessionID.String
|
||||
f.IsStale = isStale != 0
|
||||
f.IsArchived = isArchived != 0
|
||||
f.IsGene = f.Source == "genome" // Genome Layer: auto-detect from source
|
||||
f.CreatedAt, _ = time.Parse(timeFormat, createdAt)
|
||||
f.ValidFrom, _ = time.Parse(timeFormat, validFrom)
|
||||
f.UpdatedAt = f.CreatedAt
|
||||
|
||||
if validUntil.Valid {
|
||||
t, _ := time.Parse(timeFormat, validUntil.String)
|
||||
f.ValidUntil = &t
|
||||
}
|
||||
|
||||
if len(embeddingBlob) > 0 {
|
||||
if err := json.Unmarshal(embeddingBlob, &f.Embedding); err != nil {
|
||||
return nil, fmt.Errorf("decode embedding: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if ttlJSON.Valid && ttlJSON.String != "" {
|
||||
f.TTL = &memory.TTLConfig{}
|
||||
if err := json.Unmarshal([]byte(ttlJSON.String), f.TTL); err != nil {
|
||||
return nil, fmt.Errorf("decode ttl_config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
facts = append(facts, &f)
|
||||
}
|
||||
return facts, rows.Err()
|
||||
}
|
||||
|
||||
func encodeEmbedding(embedding []float64) ([]byte, error) {
|
||||
if len(embedding) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := json.Marshal(embedding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode embedding: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func encodeTTL(ttl *memory.TTLConfig) (*string, error) {
|
||||
if ttl == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := json.Marshal(ttl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode ttl: %w", err)
|
||||
}
|
||||
s := string(data)
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func nullStr(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Ensure FactRepo implements memory.FactStore.
|
||||
var _ memory.FactStore = (*FactRepo)(nil)
|
||||
|
||||
// ListGenes returns all genome facts (immutable survival invariants).
|
||||
func (r *FactRepo) ListGenes(ctx context.Context) ([]*memory.Fact, error) {
|
||||
rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts WHERE source = 'genome' ORDER BY created_at ASC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list genes: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanFacts(rows)
|
||||
}
|
||||
|
||||
// --- v3.3 Context GC ---
|
||||
|
||||
// TouchFact increments hit_count and updates last_accessed_at.
|
||||
func (r *FactRepo) TouchFact(ctx context.Context, id string) error {
|
||||
now := time.Now().Format(timeFormat)
|
||||
_, err := r.db.Exec(`UPDATE hierarchical_facts
|
||||
SET hit_count = COALESCE(hit_count, 0) + 1, last_accessed_at = ?
|
||||
WHERE id = ?`, now, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("touch fact: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetColdFacts returns facts with hit_count=0, created >30 days ago.
|
||||
// Genes (source='genome') and archived facts are excluded.
|
||||
func (r *FactRepo) GetColdFacts(ctx context.Context, limit int) ([]*memory.Fact, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
thirtyDaysAgo := time.Now().AddDate(0, 0, -30).Format(timeFormat)
|
||||
rows, err := r.db.Query(`SELECT id, content, level, domain, module, code_ref, parent_id,
|
||||
embedding, ttl_config, created_at, valid_from, valid_until,
|
||||
is_stale, is_archived, confidence, source, session_id
|
||||
FROM hierarchical_facts
|
||||
WHERE COALESCE(hit_count, 0) = 0
|
||||
AND created_at < ?
|
||||
AND source != 'genome'
|
||||
AND is_archived = 0
|
||||
ORDER BY created_at ASC
|
||||
LIMIT ?`, thirtyDaysAgo, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get cold facts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanFacts(rows)
|
||||
}
|
||||
|
||||
// CompressFacts archives originals and creates a summary fact.
|
||||
// Genes (source='genome') are silently skipped.
|
||||
func (r *FactRepo) CompressFacts(ctx context.Context, ids []string, summary string) (string, error) {
|
||||
if len(ids) == 0 {
|
||||
return "", fmt.Errorf("no fact IDs provided")
|
||||
}
|
||||
if summary == "" {
|
||||
return "", fmt.Errorf("summary text is required")
|
||||
}
|
||||
|
||||
// Determine domain from first non-gene fact.
|
||||
var domain string
|
||||
var level memory.HierLevel
|
||||
for _, id := range ids {
|
||||
f, err := r.Get(ctx, id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if f.IsGene {
|
||||
continue // skip genes
|
||||
}
|
||||
domain = f.Domain
|
||||
level = f.Level
|
||||
break
|
||||
}
|
||||
|
||||
// Archive originals (skip genes).
|
||||
archived := 0
|
||||
for _, id := range ids {
|
||||
f, err := r.Get(ctx, id)
|
||||
if err != nil || f.IsGene {
|
||||
continue
|
||||
}
|
||||
f.Archive()
|
||||
if err := r.Update(ctx, f); err != nil {
|
||||
return "", fmt.Errorf("archive fact %s: %w", id, err)
|
||||
}
|
||||
archived++
|
||||
}
|
||||
|
||||
if archived == 0 {
|
||||
return "", fmt.Errorf("no facts were archived (all genes or not found)")
|
||||
}
|
||||
|
||||
// Create summary fact.
|
||||
summaryFact := memory.NewFact(summary, level, domain, "")
|
||||
summaryFact.Source = "consolidation"
|
||||
if err := r.Add(ctx, summaryFact); err != nil {
|
||||
return "", fmt.Errorf("create summary fact: %w", err)
|
||||
}
|
||||
|
||||
return summaryFact.ID, nil
|
||||
}
|
||||
293
internal/infrastructure/sqlite/fact_repo_test.go
Normal file
293
internal/infrastructure/sqlite/fact_repo_test.go
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestFactRepo(t *testing.T) *FactRepo {
|
||||
t.Helper()
|
||||
db, err := OpenMemory()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { db.Close() })
|
||||
|
||||
repo, err := NewFactRepo(db)
|
||||
require.NoError(t, err)
|
||||
return repo
|
||||
}
|
||||
|
||||
func TestFactRepo_Add_Get(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
fact := memory.NewFact("Go is fast", memory.LevelProject, "core", "engine")
|
||||
fact.Confidence = 0.95
|
||||
fact.Source = "manual"
|
||||
fact.CodeRef = "main.go:42"
|
||||
|
||||
err := repo.Add(ctx, fact)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := repo.Get(ctx, fact.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
assert.Equal(t, fact.ID, got.ID)
|
||||
assert.Equal(t, fact.Content, got.Content)
|
||||
assert.Equal(t, fact.Level, got.Level)
|
||||
assert.Equal(t, fact.Domain, got.Domain)
|
||||
assert.Equal(t, fact.Module, got.Module)
|
||||
assert.Equal(t, fact.CodeRef, got.CodeRef)
|
||||
assert.InDelta(t, fact.Confidence, got.Confidence, 0.001)
|
||||
assert.Equal(t, fact.Source, got.Source)
|
||||
assert.False(t, got.IsStale)
|
||||
assert.False(t, got.IsArchived)
|
||||
}
|
||||
|
||||
func TestFactRepo_Get_NotFound(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
got, err := repo.Get(ctx, "nonexistent")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestFactRepo_Update(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
fact := memory.NewFact("original", memory.LevelProject, "core", "")
|
||||
require.NoError(t, repo.Add(ctx, fact))
|
||||
|
||||
fact.Content = "updated"
|
||||
fact.IsStale = true
|
||||
require.NoError(t, repo.Update(ctx, fact))
|
||||
|
||||
got, err := repo.Get(ctx, fact.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated", got.Content)
|
||||
assert.True(t, got.IsStale)
|
||||
}
|
||||
|
||||
func TestFactRepo_Delete(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
fact := memory.NewFact("to delete", memory.LevelProject, "", "")
|
||||
require.NoError(t, repo.Add(ctx, fact))
|
||||
|
||||
err := repo.Delete(ctx, fact.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := repo.Get(ctx, fact.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestFactRepo_ListByDomain(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("fact1", memory.LevelProject, "backend", "")
|
||||
f2 := memory.NewFact("fact2", memory.LevelDomain, "backend", "")
|
||||
f3 := memory.NewFact("fact3", memory.LevelProject, "frontend", "")
|
||||
f4 := memory.NewFact("stale", memory.LevelProject, "backend", "")
|
||||
f4.IsStale = true
|
||||
|
||||
for _, f := range []*memory.Fact{f1, f2, f3, f4} {
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
}
|
||||
|
||||
// Without stale
|
||||
facts, err := repo.ListByDomain(ctx, "backend", false)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 2)
|
||||
|
||||
// With stale
|
||||
facts, err = repo.ListByDomain(ctx, "backend", true)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 3)
|
||||
}
|
||||
|
||||
func TestFactRepo_ListByLevel(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("f1", memory.LevelProject, "", "")
|
||||
f2 := memory.NewFact("f2", memory.LevelProject, "", "")
|
||||
f3 := memory.NewFact("f3", memory.LevelDomain, "", "")
|
||||
|
||||
for _, f := range []*memory.Fact{f1, f2, f3} {
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
}
|
||||
|
||||
facts, err := repo.ListByLevel(ctx, memory.LevelProject)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, facts, 2)
|
||||
}
|
||||
|
||||
func TestFactRepo_ListDomains(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("f1", memory.LevelProject, "backend", "")
|
||||
f2 := memory.NewFact("f2", memory.LevelProject, "frontend", "")
|
||||
f3 := memory.NewFact("f3", memory.LevelProject, "backend", "")
|
||||
|
||||
for _, f := range []*memory.Fact{f1, f2, f3} {
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
}
|
||||
|
||||
domains, err := repo.ListDomains(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, domains, 2)
|
||||
assert.Contains(t, domains, "backend")
|
||||
assert.Contains(t, domains, "frontend")
|
||||
}
|
||||
|
||||
func TestFactRepo_GetStale(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("fresh", memory.LevelProject, "", "")
|
||||
f2 := memory.NewFact("stale", memory.LevelProject, "", "")
|
||||
f2.IsStale = true
|
||||
f3 := memory.NewFact("archived", memory.LevelProject, "", "")
|
||||
f3.IsStale = true
|
||||
f3.IsArchived = true
|
||||
|
||||
for _, f := range []*memory.Fact{f1, f2, f3} {
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
}
|
||||
|
||||
// Without archived
|
||||
stale, err := repo.GetStale(ctx, false)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stale, 1)
|
||||
|
||||
// With archived
|
||||
stale, err = repo.GetStale(ctx, true)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stale, 2)
|
||||
}
|
||||
|
||||
func TestFactRepo_Search(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("Go concurrency patterns", memory.LevelProject, "", "")
|
||||
f2 := memory.NewFact("Python is slow", memory.LevelProject, "", "")
|
||||
f3 := memory.NewFact("Go channels are great", memory.LevelDomain, "", "")
|
||||
|
||||
for _, f := range []*memory.Fact{f1, f2, f3} {
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
}
|
||||
|
||||
results, err := repo.Search(ctx, "Go", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
}
|
||||
|
||||
func TestFactRepo_GetExpired(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("no ttl", memory.LevelProject, "", "")
|
||||
|
||||
f2 := memory.NewFact("expired", memory.LevelProject, "", "")
|
||||
f2.TTL = &memory.TTLConfig{TTLSeconds: 1, OnExpire: memory.OnExpireMarkStale}
|
||||
f2.CreatedAt = time.Now().Add(-2 * time.Hour)
|
||||
f2.ValidFrom = f2.CreatedAt
|
||||
|
||||
for _, f := range []*memory.Fact{f1, f2} {
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
}
|
||||
|
||||
expired, err := repo.GetExpired(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, expired, 1)
|
||||
assert.Equal(t, f2.ID, expired[0].ID)
|
||||
}
|
||||
|
||||
func TestFactRepo_RefreshTTL(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f := memory.NewFact("refreshable", memory.LevelProject, "", "")
|
||||
f.TTL = &memory.TTLConfig{TTLSeconds: 3600, OnExpire: memory.OnExpireMarkStale}
|
||||
f.CreatedAt = time.Now().Add(-2 * time.Hour)
|
||||
f.ValidFrom = f.CreatedAt
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
|
||||
require.NoError(t, repo.RefreshTTL(ctx, f.ID))
|
||||
|
||||
got, err := repo.Get(ctx, f.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, got.CreatedAt.After(f.CreatedAt))
|
||||
}
|
||||
|
||||
func TestFactRepo_Stats(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("f1", memory.LevelProject, "backend", "")
|
||||
f2 := memory.NewFact("f2", memory.LevelDomain, "backend", "")
|
||||
f2.IsStale = true
|
||||
f3 := memory.NewFact("f3", memory.LevelProject, "frontend", "")
|
||||
f3.Embedding = []float64{0.1, 0.2}
|
||||
|
||||
for _, f := range []*memory.Fact{f1, f2, f3} {
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
}
|
||||
|
||||
stats, err := repo.Stats(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, stats.TotalFacts)
|
||||
assert.Equal(t, 2, stats.ByLevel[memory.LevelProject])
|
||||
assert.Equal(t, 1, stats.ByLevel[memory.LevelDomain])
|
||||
assert.Equal(t, 2, stats.ByDomain["backend"])
|
||||
assert.Equal(t, 1, stats.ByDomain["frontend"])
|
||||
assert.Equal(t, 1, stats.StaleCount)
|
||||
assert.Equal(t, 1, stats.WithEmbeddings)
|
||||
}
|
||||
|
||||
func TestFactRepo_EmbeddingRoundTrip(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f := memory.NewFact("with embedding", memory.LevelProject, "", "")
|
||||
f.Embedding = []float64{0.1, 0.2, 0.3, -0.5}
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
|
||||
got, err := repo.Get(ctx, f.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got.Embedding, 4)
|
||||
assert.InDelta(t, 0.1, got.Embedding[0], 0.0001)
|
||||
assert.InDelta(t, -0.5, got.Embedding[3], 0.0001)
|
||||
}
|
||||
|
||||
func TestFactRepo_TTLConfigRoundTrip(t *testing.T) {
|
||||
repo := newTestFactRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f := memory.NewFact("with ttl", memory.LevelProject, "", "")
|
||||
f.TTL = &memory.TTLConfig{
|
||||
TTLSeconds: 3600,
|
||||
RefreshTrigger: "main.go",
|
||||
OnExpire: memory.OnExpireArchive,
|
||||
}
|
||||
require.NoError(t, repo.Add(ctx, f))
|
||||
|
||||
got, err := repo.Get(ctx, f.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.TTL)
|
||||
assert.Equal(t, 3600, got.TTL.TTLSeconds)
|
||||
assert.Equal(t, "main.go", got.TTL.RefreshTrigger)
|
||||
assert.Equal(t, memory.OnExpireArchive, got.TTL.OnExpire)
|
||||
}
|
||||
128
internal/infrastructure/sqlite/interaction_repo.go
Normal file
128
internal/infrastructure/sqlite/interaction_repo.go
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// InteractionEntry represents a single tool call record in the interaction log.
|
||||
type InteractionEntry struct {
|
||||
ID int64 `json:"id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
ArgsJSON string `json:"args_json,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Processed bool `json:"processed"`
|
||||
}
|
||||
|
||||
// InteractionLogRepo provides crash-safe tool call recording in SQLite.
|
||||
// Every tool call is INSERT-ed immediately; WAL mode ensures durability
|
||||
// even on kill -9 / terminal close.
|
||||
type InteractionLogRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewInteractionLogRepo creates the interaction_log table if needed and returns the repo.
|
||||
func NewInteractionLogRepo(db *DB) (*InteractionLogRepo, error) {
|
||||
createSQL := `
|
||||
CREATE TABLE IF NOT EXISTS interaction_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
tool_name TEXT NOT NULL,
|
||||
args_json TEXT,
|
||||
timestamp TEXT NOT NULL,
|
||||
processed INTEGER DEFAULT 0
|
||||
)`
|
||||
if _, err := db.Exec(createSQL); err != nil {
|
||||
return nil, fmt.Errorf("create interaction_log table: %w", err)
|
||||
}
|
||||
return &InteractionLogRepo{db: db}, nil
|
||||
}
|
||||
|
||||
// Record inserts a tool call entry. This is designed to be fire-and-forget
|
||||
// from the middleware — errors are logged but don't break the tool call.
|
||||
func (r *InteractionLogRepo) Record(ctx context.Context, toolName string, args map[string]interface{}) error {
|
||||
argsJSON := ""
|
||||
if len(args) > 0 {
|
||||
// Only keep string arguments to reduce noise
|
||||
filtered := make(map[string]string)
|
||||
for k, v := range args {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
// Truncate very long values
|
||||
if len(s) > 200 {
|
||||
s = s[:200] + "..."
|
||||
}
|
||||
filtered[k] = s
|
||||
}
|
||||
}
|
||||
if len(filtered) > 0 {
|
||||
data, _ := json.Marshal(filtered)
|
||||
argsJSON = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := r.db.Exec(
|
||||
`INSERT INTO interaction_log (tool_name, args_json, timestamp) VALUES (?, ?, ?)`,
|
||||
toolName, argsJSON, now,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetUnprocessed returns all entries not yet processed, ordered oldest first.
|
||||
func (r *InteractionLogRepo) GetUnprocessed(ctx context.Context) ([]InteractionEntry, error) {
|
||||
rows, err := r.db.Query(
|
||||
`SELECT id, tool_name, args_json, timestamp, processed
|
||||
FROM interaction_log WHERE processed = 0 ORDER BY id ASC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query unprocessed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []InteractionEntry
|
||||
for rows.Next() {
|
||||
var e InteractionEntry
|
||||
var ts string
|
||||
var proc int
|
||||
if err := rows.Scan(&e.ID, &e.ToolName, &e.ArgsJSON, &ts, &proc); err != nil {
|
||||
return nil, fmt.Errorf("scan entry: %w", err)
|
||||
}
|
||||
e.Timestamp, _ = time.Parse(time.RFC3339, ts)
|
||||
e.Processed = proc != 0
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// MarkProcessed marks entries as processed by their IDs.
|
||||
func (r *InteractionLogRepo) MarkProcessed(ctx context.Context, ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, id := range ids {
|
||||
if _, err := r.db.Exec(`UPDATE interaction_log SET processed = 1 WHERE id = ?`, id); err != nil {
|
||||
return fmt.Errorf("mark processed id=%d: %w", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns the total number of entries and unprocessed count.
|
||||
func (r *InteractionLogRepo) Count(ctx context.Context) (total int, unprocessed int, err error) {
|
||||
row := r.db.QueryRow(`SELECT COUNT(*), COALESCE(SUM(CASE WHEN processed=0 THEN 1 ELSE 0 END), 0) FROM interaction_log`)
|
||||
err = row.Scan(&total, &unprocessed)
|
||||
return
|
||||
}
|
||||
|
||||
// Prune deletes processed entries older than the given duration.
|
||||
func (r *InteractionLogRepo) Prune(ctx context.Context, olderThan time.Duration) (int64, error) {
|
||||
cutoff := time.Now().UTC().Add(-olderThan).Format(time.RFC3339)
|
||||
result, err := r.db.Exec(
|
||||
`DELETE FROM interaction_log WHERE processed = 1 AND timestamp <= ?`, cutoff,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
186
internal/infrastructure/sqlite/interaction_repo_test.go
Normal file
186
internal/infrastructure/sqlite/interaction_repo_test.go
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupInteractionRepo(t *testing.T) *InteractionLogRepo {
|
||||
t.Helper()
|
||||
db, err := OpenMemory()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { db.Close() })
|
||||
repo, err := NewInteractionLogRepo(db)
|
||||
require.NoError(t, err)
|
||||
return repo
|
||||
}
|
||||
|
||||
func TestNewInteractionLogRepo(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
require.NotNil(t, repo)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_Record(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := repo.Record(ctx, "add_fact", map[string]interface{}{
|
||||
"content": "test fact",
|
||||
"level": 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
total, unproc, err := repo.Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, total)
|
||||
assert.Equal(t, 1, unproc)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_Record_EmptyArgs(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := repo.Record(ctx, "health", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
entries, err := repo.GetUnprocessed(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
assert.Equal(t, "health", entries[0].ToolName)
|
||||
assert.Empty(t, entries[0].ArgsJSON)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_Record_TruncatesLongValues(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
longContent := make([]byte, 500)
|
||||
for i := range longContent {
|
||||
longContent[i] = 'x'
|
||||
}
|
||||
|
||||
err := repo.Record(ctx, "add_fact", map[string]interface{}{
|
||||
"content": string(longContent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
entries, err := repo.GetUnprocessed(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
// args_json should contain truncated value (200 chars + "...")
|
||||
assert.Contains(t, entries[0].ArgsJSON, "...")
|
||||
assert.Less(t, len(entries[0].ArgsJSON), 300)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_GetUnprocessed(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_ = repo.Record(ctx, "add_fact", map[string]interface{}{"content": "a"})
|
||||
_ = repo.Record(ctx, "search_facts", map[string]interface{}{"query": "b"})
|
||||
_ = repo.Record(ctx, "health", nil)
|
||||
|
||||
entries, err := repo.GetUnprocessed(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, entries, 3)
|
||||
// Ordered by id ASC
|
||||
assert.Equal(t, "add_fact", entries[0].ToolName)
|
||||
assert.Equal(t, "search_facts", entries[1].ToolName)
|
||||
assert.Equal(t, "health", entries[2].ToolName)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_MarkProcessed(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_ = repo.Record(ctx, "tool_a", nil)
|
||||
_ = repo.Record(ctx, "tool_b", nil)
|
||||
|
||||
entries, _ := repo.GetUnprocessed(ctx)
|
||||
require.Len(t, entries, 2)
|
||||
|
||||
// Mark first as processed
|
||||
err := repo.MarkProcessed(ctx, []int64{entries[0].ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, _ := repo.GetUnprocessed(ctx)
|
||||
assert.Len(t, remaining, 1)
|
||||
assert.Equal(t, "tool_b", remaining[0].ToolName)
|
||||
|
||||
total, unproc, _ := repo.Count(ctx)
|
||||
assert.Equal(t, 2, total)
|
||||
assert.Equal(t, 1, unproc)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_MarkProcessed_Empty(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
err := repo.MarkProcessed(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_Prune(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert and mark as processed
|
||||
_ = repo.Record(ctx, "old_tool", nil)
|
||||
entries, _ := repo.GetUnprocessed(ctx)
|
||||
_ = repo.MarkProcessed(ctx, []int64{entries[0].ID})
|
||||
|
||||
// Prune with 0 duration should delete all processed
|
||||
deleted, err := repo.Prune(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), deleted)
|
||||
|
||||
total, _, _ := repo.Count(ctx)
|
||||
assert.Equal(t, 0, total)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_Prune_KeepsUnprocessed(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_ = repo.Record(ctx, "unprocessed_tool", nil)
|
||||
|
||||
// Prune should not delete unprocessed entries
|
||||
deleted, err := repo.Prune(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), deleted)
|
||||
|
||||
total, _, _ := repo.Count(ctx)
|
||||
assert.Equal(t, 1, total)
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_Timestamps(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
before := time.Now().UTC()
|
||||
_ = repo.Record(ctx, "timed_tool", nil)
|
||||
after := time.Now().UTC()
|
||||
|
||||
entries, _ := repo.GetUnprocessed(ctx)
|
||||
require.Len(t, entries, 1)
|
||||
|
||||
ts := entries[0].Timestamp
|
||||
assert.False(t, ts.Before(before.Add(-time.Second)))
|
||||
assert.False(t, ts.After(after.Add(time.Second)))
|
||||
}
|
||||
|
||||
func TestInteractionLogRepo_MultipleRecords_Count(t *testing.T) {
|
||||
repo := setupInteractionRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_ = repo.Record(ctx, "tool", nil)
|
||||
}
|
||||
|
||||
total, unproc, err := repo.Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 10, total)
|
||||
assert.Equal(t, 10, unproc)
|
||||
}
|
||||
94
internal/infrastructure/sqlite/peer_repo.go
Normal file
94
internal/infrastructure/sqlite/peer_repo.go
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/peer"
|
||||
)
|
||||
|
||||
// PeerRepo implements peer.PeerStore using SQLite.
|
||||
type PeerRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewPeerRepo creates a PeerRepo and ensures the peers table exists.
|
||||
func NewPeerRepo(db *DB) (*PeerRepo, error) {
|
||||
repo := &PeerRepo{db: db}
|
||||
if err := repo.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("peer repo migrate: %w", err)
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (r *PeerRepo) migrate() error {
|
||||
stmt := `CREATE TABLE IF NOT EXISTS peers (
|
||||
peer_id TEXT PRIMARY KEY,
|
||||
node_name TEXT NOT NULL,
|
||||
genome_hash TEXT NOT NULL,
|
||||
trust_level INTEGER DEFAULT 0,
|
||||
last_seen TEXT NOT NULL,
|
||||
fact_count INTEGER DEFAULT 0,
|
||||
handshake_at TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)`
|
||||
if _, err := r.db.Exec(stmt); err != nil {
|
||||
return fmt.Errorf("create peers table: %w", err)
|
||||
}
|
||||
_, _ = r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_peers_trust ON peers(trust_level)`)
|
||||
_, _ = r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_peers_last_seen ON peers(last_seen)`)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePeer upserts a peer record.
|
||||
func (r *PeerRepo) SavePeer(_ context.Context, p *peer.PeerInfo) error {
|
||||
stmt := `INSERT OR REPLACE INTO peers (peer_id, node_name, genome_hash, trust_level, last_seen, fact_count, handshake_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
hsAt := ""
|
||||
if !p.HandshakeAt.IsZero() {
|
||||
hsAt = p.HandshakeAt.Format(timeFormat)
|
||||
}
|
||||
_, err := r.db.Exec(stmt,
|
||||
p.PeerID, p.NodeName, p.GenomeHash, int(p.Trust),
|
||||
p.LastSeen.Format(timeFormat), p.FactCount, hsAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadPeers returns all stored peers.
|
||||
func (r *PeerRepo) LoadPeers(_ context.Context) ([]*peer.PeerInfo, error) {
|
||||
rows, err := r.db.Query(`SELECT peer_id, node_name, genome_hash, trust_level, last_seen, fact_count, handshake_at FROM peers`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var peers []*peer.PeerInfo
|
||||
for rows.Next() {
|
||||
var p peer.PeerInfo
|
||||
var trustInt int
|
||||
var lastSeenStr, handshakeStr string
|
||||
if err := rows.Scan(&p.PeerID, &p.NodeName, &p.GenomeHash, &trustInt, &lastSeenStr, &p.FactCount, &handshakeStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Trust = peer.TrustLevel(trustInt)
|
||||
p.LastSeen, _ = time.Parse(timeFormat, lastSeenStr)
|
||||
if handshakeStr != "" {
|
||||
p.HandshakeAt, _ = time.Parse(timeFormat, handshakeStr)
|
||||
}
|
||||
peers = append(peers, &p)
|
||||
}
|
||||
return peers, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteExpired removes peers not seen within the given duration.
|
||||
func (r *PeerRepo) DeleteExpired(_ context.Context, olderThan time.Duration) (int, error) {
|
||||
cutoff := time.Now().Add(-olderThan).Format(timeFormat)
|
||||
result, err := r.db.Exec(`DELETE FROM peers WHERE last_seen < ?`, cutoff)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
return int(n), nil
|
||||
}
|
||||
366
internal/infrastructure/sqlite/soc_repo.go
Normal file
366
internal/infrastructure/sqlite/soc_repo.go
Normal file
|
|
@ -0,0 +1,366 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/soc"
|
||||
)
|
||||
|
||||
// SOCRepo provides SQLite persistence for SOC events, incidents, and sensors.
|
||||
type SOCRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewSOCRepo creates and initializes SOC tables.
|
||||
func NewSOCRepo(db *DB) (*SOCRepo, error) {
|
||||
repo := &SOCRepo{db: db}
|
||||
if err := repo.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("soc_repo: migrate: %w", err)
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (r *SOCRepo) migrate() error {
|
||||
tables := []string{
|
||||
`CREATE TABLE IF NOT EXISTS soc_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
sensor_id TEXT NOT NULL DEFAULT '',
|
||||
severity TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
subcategory TEXT NOT NULL DEFAULT '',
|
||||
confidence REAL NOT NULL DEFAULT 0.0,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
session_id TEXT NOT NULL DEFAULT '',
|
||||
decision_hash TEXT NOT NULL DEFAULT '',
|
||||
verdict TEXT NOT NULL DEFAULT 'REVIEW',
|
||||
timestamp TEXT NOT NULL,
|
||||
metadata TEXT NOT NULL DEFAULT '{}'
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS soc_incidents (
|
||||
id TEXT PRIMARY KEY,
|
||||
status TEXT NOT NULL DEFAULT 'OPEN',
|
||||
severity TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
event_ids TEXT NOT NULL DEFAULT '[]',
|
||||
event_count INTEGER NOT NULL DEFAULT 0,
|
||||
decision_chain_anchor TEXT NOT NULL DEFAULT '',
|
||||
chain_length INTEGER NOT NULL DEFAULT 0,
|
||||
correlation_rule TEXT NOT NULL DEFAULT '',
|
||||
kill_chain_phase TEXT NOT NULL DEFAULT '',
|
||||
mitre_mapping TEXT NOT NULL DEFAULT '[]',
|
||||
playbook_applied TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
resolved_at TEXT
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS soc_sensors (
|
||||
sensor_id TEXT PRIMARY KEY,
|
||||
sensor_type TEXT NOT NULL,
|
||||
status TEXT DEFAULT 'UNKNOWN',
|
||||
first_seen TEXT NOT NULL,
|
||||
last_seen TEXT NOT NULL,
|
||||
event_count INTEGER DEFAULT 0,
|
||||
missed_heartbeats INTEGER DEFAULT 0,
|
||||
hostname TEXT NOT NULL DEFAULT '',
|
||||
version TEXT NOT NULL DEFAULT ''
|
||||
)`,
|
||||
// Indexes for common queries.
|
||||
`CREATE INDEX IF NOT EXISTS idx_soc_events_timestamp ON soc_events(timestamp)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_soc_events_severity ON soc_events(severity)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_soc_events_category ON soc_events(category)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_soc_events_sensor ON soc_events(sensor_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_soc_incidents_status ON soc_incidents(status)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_soc_sensors_status ON soc_sensors(status)`,
|
||||
}
|
||||
for _, ddl := range tables {
|
||||
if _, err := r.db.Exec(ddl); err != nil {
|
||||
return fmt.Errorf("exec %q: %w", ddl[:40], err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// === Events ===
|
||||
|
||||
// InsertEvent persists a SOC event.
|
||||
func (r *SOCRepo) InsertEvent(e soc.SOCEvent) error {
|
||||
_, err := r.db.Exec(
|
||||
`INSERT INTO soc_events (id, source, sensor_id, severity, category, subcategory,
|
||||
confidence, description, session_id, decision_hash, verdict, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
e.ID, e.Source, e.SensorID, e.Severity, e.Category, e.Subcategory,
|
||||
e.Confidence, e.Description, e.SessionID, e.DecisionHash, e.Verdict,
|
||||
e.Timestamp.Format(time.RFC3339Nano),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListEvents returns events ordered by timestamp (newest first), with limit.
|
||||
func (r *SOCRepo) ListEvents(limit int) ([]soc.SOCEvent, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := r.db.Query(
|
||||
`SELECT id, source, sensor_id, severity, category, subcategory,
|
||||
confidence, description, session_id, decision_hash, verdict, timestamp
|
||||
FROM soc_events ORDER BY timestamp DESC LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanEvents(rows)
|
||||
}
|
||||
|
||||
// ListEventsByCategory returns events filtered by category.
|
||||
func (r *SOCRepo) ListEventsByCategory(category string, limit int) ([]soc.SOCEvent, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := r.db.Query(
|
||||
`SELECT id, source, sensor_id, severity, category, subcategory,
|
||||
confidence, description, session_id, decision_hash, verdict, timestamp
|
||||
FROM soc_events WHERE category = ? ORDER BY timestamp DESC LIMIT ?`,
|
||||
category, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanEvents(rows)
|
||||
}
|
||||
|
||||
// CountEvents returns total event count.
|
||||
func (r *SOCRepo) CountEvents() (int, error) {
|
||||
var count int
|
||||
err := r.db.QueryRow("SELECT COUNT(*) FROM soc_events").Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountEventsSince returns events in the given time window.
|
||||
func (r *SOCRepo) CountEventsSince(since time.Time) (int, error) {
|
||||
var count int
|
||||
err := r.db.QueryRow(
|
||||
"SELECT COUNT(*) FROM soc_events WHERE timestamp >= ?",
|
||||
since.Format(time.RFC3339Nano),
|
||||
).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func scanEvents(rows *sql.Rows) ([]soc.SOCEvent, error) {
|
||||
var events []soc.SOCEvent
|
||||
for rows.Next() {
|
||||
var e soc.SOCEvent
|
||||
var ts string
|
||||
err := rows.Scan(&e.ID, &e.Source, &e.SensorID, &e.Severity,
|
||||
&e.Category, &e.Subcategory, &e.Confidence, &e.Description,
|
||||
&e.SessionID, &e.DecisionHash, &e.Verdict, &ts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.Timestamp, _ = time.Parse(time.RFC3339Nano, ts)
|
||||
events = append(events, e)
|
||||
}
|
||||
return events, rows.Err()
|
||||
}
|
||||
|
||||
// === Incidents ===
|
||||
|
||||
// InsertIncident persists a new incident.
|
||||
func (r *SOCRepo) InsertIncident(inc soc.Incident) error {
|
||||
_, err := r.db.Exec(
|
||||
`INSERT INTO soc_incidents (id, status, severity, title, description,
|
||||
event_count, decision_chain_anchor, chain_length, correlation_rule,
|
||||
kill_chain_phase, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
inc.ID, inc.Status, inc.Severity, inc.Title, inc.Description,
|
||||
inc.EventCount, inc.DecisionChainAnchor, inc.ChainLength,
|
||||
inc.CorrelationRule, inc.KillChainPhase,
|
||||
inc.CreatedAt.Format(time.RFC3339Nano),
|
||||
inc.UpdatedAt.Format(time.RFC3339Nano),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetIncident retrieves an incident by ID.
|
||||
func (r *SOCRepo) GetIncident(id string) (*soc.Incident, error) {
|
||||
var inc soc.Incident
|
||||
var createdAt, updatedAt string
|
||||
var resolvedAt sql.NullString
|
||||
err := r.db.QueryRow(
|
||||
`SELECT id, status, severity, title, description, event_count,
|
||||
decision_chain_anchor, chain_length, correlation_rule,
|
||||
kill_chain_phase, playbook_applied, created_at, updated_at, resolved_at
|
||||
FROM soc_incidents WHERE id = ?`, id,
|
||||
).Scan(&inc.ID, &inc.Status, &inc.Severity, &inc.Title, &inc.Description,
|
||||
&inc.EventCount, &inc.DecisionChainAnchor, &inc.ChainLength,
|
||||
&inc.CorrelationRule, &inc.KillChainPhase, &inc.PlaybookApplied,
|
||||
&createdAt, &updatedAt, &resolvedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inc.CreatedAt, _ = time.Parse(time.RFC3339Nano, createdAt)
|
||||
inc.UpdatedAt, _ = time.Parse(time.RFC3339Nano, updatedAt)
|
||||
if resolvedAt.Valid {
|
||||
t, _ := time.Parse(time.RFC3339Nano, resolvedAt.String)
|
||||
inc.ResolvedAt = &t
|
||||
}
|
||||
return &inc, nil
|
||||
}
|
||||
|
||||
// ListIncidents returns incidents, optionally filtered by status.
|
||||
func (r *SOCRepo) ListIncidents(status string, limit int) ([]soc.Incident, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if status != "" {
|
||||
rows, err = r.db.Query(
|
||||
`SELECT id, status, severity, title, description, event_count,
|
||||
decision_chain_anchor, chain_length, correlation_rule,
|
||||
kill_chain_phase, playbook_applied, created_at, updated_at
|
||||
FROM soc_incidents WHERE status = ? ORDER BY created_at DESC LIMIT ?`,
|
||||
status, limit)
|
||||
} else {
|
||||
rows, err = r.db.Query(
|
||||
`SELECT id, status, severity, title, description, event_count,
|
||||
decision_chain_anchor, chain_length, correlation_rule,
|
||||
kill_chain_phase, playbook_applied, created_at, updated_at
|
||||
FROM soc_incidents ORDER BY created_at DESC LIMIT ?`, limit)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var incidents []soc.Incident
|
||||
for rows.Next() {
|
||||
var inc soc.Incident
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(&inc.ID, &inc.Status, &inc.Severity, &inc.Title,
|
||||
&inc.Description, &inc.EventCount, &inc.DecisionChainAnchor,
|
||||
&inc.ChainLength, &inc.CorrelationRule, &inc.KillChainPhase,
|
||||
&inc.PlaybookApplied, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inc.CreatedAt, _ = time.Parse(time.RFC3339Nano, createdAt)
|
||||
inc.UpdatedAt, _ = time.Parse(time.RFC3339Nano, updatedAt)
|
||||
incidents = append(incidents, inc)
|
||||
}
|
||||
return incidents, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateIncidentStatus updates status (and optionally resolved_at).
|
||||
func (r *SOCRepo) UpdateIncidentStatus(id string, status soc.IncidentStatus) error {
|
||||
now := time.Now().Format(time.RFC3339Nano)
|
||||
if status == soc.StatusResolved || status == soc.StatusFalsePositive {
|
||||
_, err := r.db.Exec(
|
||||
`UPDATE soc_incidents SET status = ?, updated_at = ?, resolved_at = ? WHERE id = ?`,
|
||||
status, now, now, id)
|
||||
return err
|
||||
}
|
||||
_, err := r.db.Exec(
|
||||
`UPDATE soc_incidents SET status = ?, updated_at = ? WHERE id = ?`,
|
||||
status, now, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// CountOpenIncidents returns count of non-resolved incidents.
|
||||
func (r *SOCRepo) CountOpenIncidents() (int, error) {
|
||||
var count int
|
||||
err := r.db.QueryRow(
|
||||
"SELECT COUNT(*) FROM soc_incidents WHERE status IN ('OPEN', 'INVESTIGATING')",
|
||||
).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// === Sensors ===
|
||||
|
||||
// UpsertSensor creates or updates a sensor entry.
|
||||
func (r *SOCRepo) UpsertSensor(s soc.Sensor) error {
|
||||
_, err := r.db.Exec(
|
||||
`INSERT INTO soc_sensors (sensor_id, sensor_type, status, first_seen, last_seen,
|
||||
event_count, missed_heartbeats, hostname, version)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(sensor_id) DO UPDATE SET
|
||||
status = excluded.status,
|
||||
last_seen = excluded.last_seen,
|
||||
event_count = excluded.event_count,
|
||||
missed_heartbeats = excluded.missed_heartbeats`,
|
||||
s.SensorID, s.SensorType, s.Status,
|
||||
s.FirstSeen.Format(time.RFC3339Nano),
|
||||
s.LastSeen.Format(time.RFC3339Nano),
|
||||
s.EventCount, s.MissedHeartbeats, s.Hostname, s.Version,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetSensor retrieves a sensor by ID.
|
||||
func (r *SOCRepo) GetSensor(id string) (*soc.Sensor, error) {
|
||||
var s soc.Sensor
|
||||
var firstSeen, lastSeen string
|
||||
err := r.db.QueryRow(
|
||||
`SELECT sensor_id, sensor_type, status, first_seen, last_seen,
|
||||
event_count, missed_heartbeats, hostname, version
|
||||
FROM soc_sensors WHERE sensor_id = ?`, id,
|
||||
).Scan(&s.SensorID, &s.SensorType, &s.Status, &firstSeen, &lastSeen,
|
||||
&s.EventCount, &s.MissedHeartbeats, &s.Hostname, &s.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.FirstSeen, _ = time.Parse(time.RFC3339Nano, firstSeen)
|
||||
s.LastSeen, _ = time.Parse(time.RFC3339Nano, lastSeen)
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
// ListSensors returns all registered sensors.
|
||||
func (r *SOCRepo) ListSensors() ([]soc.Sensor, error) {
|
||||
rows, err := r.db.Query(
|
||||
`SELECT sensor_id, sensor_type, status, first_seen, last_seen,
|
||||
event_count, missed_heartbeats, hostname, version
|
||||
FROM soc_sensors ORDER BY last_seen DESC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sensors []soc.Sensor
|
||||
for rows.Next() {
|
||||
var s soc.Sensor
|
||||
var firstSeen, lastSeen string
|
||||
err := rows.Scan(&s.SensorID, &s.SensorType, &s.Status,
|
||||
&firstSeen, &lastSeen, &s.EventCount, &s.MissedHeartbeats,
|
||||
&s.Hostname, &s.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.FirstSeen, _ = time.Parse(time.RFC3339Nano, firstSeen)
|
||||
s.LastSeen, _ = time.Parse(time.RFC3339Nano, lastSeen)
|
||||
sensors = append(sensors, s)
|
||||
}
|
||||
return sensors, rows.Err()
|
||||
}
|
||||
|
||||
// CountSensorsByStatus returns sensor count grouped by status.
|
||||
func (r *SOCRepo) CountSensorsByStatus() (map[soc.SensorStatus]int, error) {
|
||||
rows, err := r.db.Query("SELECT status, COUNT(*) FROM soc_sensors GROUP BY status")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[soc.SensorStatus]int)
|
||||
for rows.Next() {
|
||||
var status soc.SensorStatus
|
||||
var count int
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[status] = count
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
263
internal/infrastructure/sqlite/soc_repo_test.go
Normal file
263
internal/infrastructure/sqlite/soc_repo_test.go
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/soc"
|
||||
)
|
||||
|
||||
func setupSOCRepo(t *testing.T) *SOCRepo {
|
||||
t.Helper()
|
||||
db, err := OpenMemory()
|
||||
if err != nil {
|
||||
t.Fatalf("open memory db: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { db.Close() })
|
||||
|
||||
repo, err := NewSOCRepo(db)
|
||||
if err != nil {
|
||||
t.Fatalf("new soc repo: %v", err)
|
||||
}
|
||||
return repo
|
||||
}
|
||||
|
||||
// === Event Tests ===
|
||||
|
||||
func TestInsertAndListEvents(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
e1 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityHigh, "jailbreak", "Jailbreak detected").
|
||||
WithSensor("core-01").WithConfidence(0.95)
|
||||
e2 := soc.NewSOCEvent(soc.SourceShield, soc.SeverityMedium, "network_block", "Connection blocked").
|
||||
WithSensor("shield-01")
|
||||
|
||||
if err := repo.InsertEvent(e1); err != nil {
|
||||
t.Fatalf("insert e1: %v", err)
|
||||
}
|
||||
if err := repo.InsertEvent(e2); err != nil {
|
||||
t.Fatalf("insert e2: %v", err)
|
||||
}
|
||||
|
||||
events, err := repo.ListEvents(10)
|
||||
if err != nil {
|
||||
t.Fatalf("list events: %v", err)
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Errorf("expected 2 events, got %d", len(events))
|
||||
}
|
||||
|
||||
count, err := repo.CountEvents()
|
||||
if err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected count 2, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListEventsByCategory(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
e1 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityHigh, "jailbreak", "test")
|
||||
repo.InsertEvent(e1)
|
||||
time.Sleep(time.Millisecond)
|
||||
e2 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityMedium, "injection", "test")
|
||||
repo.InsertEvent(e2)
|
||||
time.Sleep(time.Millisecond)
|
||||
e3 := soc.NewSOCEvent(soc.SourceSentinelCore, soc.SeverityLow, "jailbreak", "test2")
|
||||
repo.InsertEvent(e3)
|
||||
|
||||
events, err := repo.ListEventsByCategory("jailbreak", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("list by category: %v", err)
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Errorf("expected 2 jailbreak events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
// === Incident Tests ===
|
||||
|
||||
func TestInsertAndGetIncident(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
inc := soc.NewIncident("Multi-stage Jailbreak", soc.SeverityCritical, "jailbreak_chain")
|
||||
inc.SetAnchor("abc123", 5)
|
||||
|
||||
if err := repo.InsertIncident(inc); err != nil {
|
||||
t.Fatalf("insert incident: %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.GetIncident(inc.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get incident: %v", err)
|
||||
}
|
||||
if got.ID != inc.ID {
|
||||
t.Errorf("ID mismatch: got %s, want %s", got.ID, inc.ID)
|
||||
}
|
||||
if got.DecisionChainAnchor != "abc123" {
|
||||
t.Errorf("anchor mismatch: got %s", got.DecisionChainAnchor)
|
||||
}
|
||||
if got.ChainLength != 5 {
|
||||
t.Errorf("chain length: got %d, want 5", got.ChainLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateIncidentStatus(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
inc := soc.NewIncident("Test", soc.SeverityHigh, "test_rule")
|
||||
repo.InsertIncident(inc)
|
||||
|
||||
if err := repo.UpdateIncidentStatus(inc.ID, soc.StatusResolved); err != nil {
|
||||
t.Fatalf("update status: %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.GetIncident(inc.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get after update: %v", err)
|
||||
}
|
||||
if got.Status != soc.StatusResolved {
|
||||
t.Errorf("expected RESOLVED, got %s", got.Status)
|
||||
}
|
||||
if got.ResolvedAt == nil {
|
||||
t.Error("resolved_at should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListIncidentsWithFilter(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
inc1 := soc.NewIncident("Open Inc", soc.SeverityHigh, "rule1")
|
||||
inc2 := soc.NewIncident("Resolved Inc", soc.SeverityMedium, "rule2")
|
||||
repo.InsertIncident(inc1)
|
||||
repo.InsertIncident(inc2)
|
||||
repo.UpdateIncidentStatus(inc2.ID, soc.StatusResolved)
|
||||
|
||||
// List OPEN only
|
||||
open, err := repo.ListIncidents("OPEN", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("list open: %v", err)
|
||||
}
|
||||
if len(open) != 1 {
|
||||
t.Errorf("expected 1 open incident, got %d", len(open))
|
||||
}
|
||||
|
||||
// List all
|
||||
all, err := repo.ListIncidents("", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("list all: %v", err)
|
||||
}
|
||||
if len(all) != 2 {
|
||||
t.Errorf("expected 2 total incidents, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountOpenIncidents(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
inc1 := soc.NewIncident("Open", soc.SeverityHigh, "r1")
|
||||
inc2 := soc.NewIncident("Investigating", soc.SeverityMedium, "r2")
|
||||
inc3 := soc.NewIncident("Resolved", soc.SeverityLow, "r3")
|
||||
repo.InsertIncident(inc1)
|
||||
repo.InsertIncident(inc2)
|
||||
repo.InsertIncident(inc3)
|
||||
repo.UpdateIncidentStatus(inc2.ID, soc.StatusInvestigating)
|
||||
repo.UpdateIncidentStatus(inc3.ID, soc.StatusResolved)
|
||||
|
||||
count, err := repo.CountOpenIncidents()
|
||||
if err != nil {
|
||||
t.Fatalf("count open: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 open (OPEN+INVESTIGATING), got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// === Sensor Tests ===
|
||||
|
||||
func TestUpsertAndGetSensor(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
s := soc.NewSensor("core-01", soc.SensorTypeSentinelCore)
|
||||
s.RecordEvent()
|
||||
s.RecordEvent()
|
||||
s.RecordEvent() // Should be HEALTHY
|
||||
|
||||
if err := repo.UpsertSensor(s); err != nil {
|
||||
t.Fatalf("upsert: %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.GetSensor("core-01")
|
||||
if err != nil {
|
||||
t.Fatalf("get sensor: %v", err)
|
||||
}
|
||||
if got.Status != soc.SensorStatusHealthy {
|
||||
t.Errorf("expected HEALTHY, got %s", got.Status)
|
||||
}
|
||||
if got.EventCount != 3 {
|
||||
t.Errorf("expected 3 events, got %d", got.EventCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSensorUpsertUpdate(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
s := soc.NewSensor("shield-01", soc.SensorTypeShield)
|
||||
repo.UpsertSensor(s)
|
||||
|
||||
// Update with new status
|
||||
s.RecordEvent()
|
||||
s.RecordEvent()
|
||||
s.RecordEvent()
|
||||
repo.UpsertSensor(s)
|
||||
|
||||
got, err := repo.GetSensor("shield-01")
|
||||
if err != nil {
|
||||
t.Fatalf("get sensor: %v", err)
|
||||
}
|
||||
if got.EventCount != 3 {
|
||||
t.Errorf("upsert should update event_count, got %d", got.EventCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSensors(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
repo.UpsertSensor(soc.NewSensor("core-01", soc.SensorTypeSentinelCore))
|
||||
repo.UpsertSensor(soc.NewSensor("shield-01", soc.SensorTypeShield))
|
||||
|
||||
sensors, err := repo.ListSensors()
|
||||
if err != nil {
|
||||
t.Fatalf("list: %v", err)
|
||||
}
|
||||
if len(sensors) != 2 {
|
||||
t.Errorf("expected 2, got %d", len(sensors))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountSensorsByStatus(t *testing.T) {
|
||||
repo := setupSOCRepo(t)
|
||||
|
||||
s1 := soc.NewSensor("core-01", soc.SensorTypeSentinelCore)
|
||||
s1.RecordEvent()
|
||||
s1.RecordEvent()
|
||||
s1.RecordEvent() // HEALTHY
|
||||
|
||||
s2 := soc.NewSensor("shield-01", soc.SensorTypeShield) // UNKNOWN
|
||||
|
||||
repo.UpsertSensor(s1)
|
||||
repo.UpsertSensor(s2)
|
||||
|
||||
counts, err := repo.CountSensorsByStatus()
|
||||
if err != nil {
|
||||
t.Fatalf("count by status: %v", err)
|
||||
}
|
||||
if counts[soc.SensorStatusHealthy] != 1 {
|
||||
t.Errorf("expected 1 HEALTHY, got %d", counts[soc.SensorStatusHealthy])
|
||||
}
|
||||
if counts[soc.SensorStatusUnknown] != 1 {
|
||||
t.Errorf("expected 1 UNKNOWN, got %d", counts[soc.SensorStatusUnknown])
|
||||
}
|
||||
}
|
||||
215
internal/infrastructure/sqlite/state_repo.go
Normal file
215
internal/infrastructure/sqlite/state_repo.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/session"
|
||||
)
|
||||
|
||||
// StateRepo implements session.StateStore using SQLite.
|
||||
// Compatible with memory_bridge.db schema (states + audit_log).
|
||||
// NOTE: The Python version uses AES-256-GCM encryption on the data blob.
|
||||
// This Go implementation stores plaintext JSON for now — encryption
|
||||
// can be layered on top via a decorator if needed.
|
||||
type StateRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewStateRepo creates a StateRepo and ensures the schema exists.
|
||||
func NewStateRepo(db *DB) (*StateRepo, error) {
|
||||
repo := &StateRepo{db: db}
|
||||
if err := repo.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("state repo migrate: %w", err)
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (r *StateRepo) migrate() error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS states (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
version INTEGER NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
checksum TEXT NOT NULL,
|
||||
data BLOB NOT NULL,
|
||||
nonce BLOB,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(session_id, version)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_session_id ON states(session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_timestamp ON states(timestamp)`,
|
||||
`CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
version INTEGER,
|
||||
timestamp TEXT NOT NULL,
|
||||
details TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_audit_session ON audit_log(session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp)`,
|
||||
}
|
||||
for _, s := range stmts {
|
||||
if _, err := r.db.Exec(s); err != nil {
|
||||
return fmt.Errorf("exec migration: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save persists a cognitive state vector snapshot.
|
||||
func (r *StateRepo) Save(ctx context.Context, state *session.CognitiveStateVector, checksum string) error {
|
||||
data, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal state: %w", err)
|
||||
}
|
||||
|
||||
// Verify checksum if provided, or compute one.
|
||||
if checksum == "" {
|
||||
h := sha256.Sum256(data)
|
||||
checksum = hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
now := time.Now().Format(timeFormat)
|
||||
|
||||
// Determine action for audit log.
|
||||
var action string
|
||||
var existingCount int
|
||||
row := r.db.QueryRow(`SELECT COUNT(*) FROM states WHERE session_id = ?`, state.SessionID)
|
||||
if err := row.Scan(&existingCount); err != nil {
|
||||
return fmt.Errorf("count existing: %w", err)
|
||||
}
|
||||
if existingCount == 0 {
|
||||
action = "create"
|
||||
} else {
|
||||
action = "update"
|
||||
}
|
||||
|
||||
_, err = r.db.Exec(`INSERT INTO states (session_id, version, timestamp, checksum, data)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
state.SessionID, state.Version, now, checksum, data,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert state: %w", err)
|
||||
}
|
||||
|
||||
// Write audit log entry.
|
||||
_, err = r.db.Exec(`INSERT INTO audit_log (session_id, action, version, timestamp, details)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
state.SessionID, action, state.Version, now,
|
||||
fmt.Sprintf("%s session %s v%d", action, state.SessionID, state.Version),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert audit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves a cognitive state vector. If version is nil, loads the latest.
|
||||
func (r *StateRepo) Load(ctx context.Context, sessionID string, version *int) (*session.CognitiveStateVector, string, error) {
|
||||
var row *sql.Row
|
||||
if version != nil {
|
||||
row = r.db.QueryRow(`SELECT data, checksum FROM states WHERE session_id = ? AND version = ?`,
|
||||
sessionID, *version)
|
||||
} else {
|
||||
row = r.db.QueryRow(`SELECT data, checksum FROM states WHERE session_id = ? ORDER BY version DESC LIMIT 1`,
|
||||
sessionID)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var checksum string
|
||||
if err := row.Scan(&data, &checksum); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, "", fmt.Errorf("session %s not found", sessionID)
|
||||
}
|
||||
return nil, "", fmt.Errorf("scan state: %w", err)
|
||||
}
|
||||
|
||||
var state session.CognitiveStateVector
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal state: %w", err)
|
||||
}
|
||||
|
||||
return &state, checksum, nil
|
||||
}
|
||||
|
||||
// ListSessions returns metadata about all persisted sessions.
|
||||
func (r *StateRepo) ListSessions(ctx context.Context) ([]session.SessionInfo, error) {
|
||||
rows, err := r.db.Query(`SELECT session_id, MAX(version) as version, MAX(timestamp) as updated_at
|
||||
FROM states GROUP BY session_id ORDER BY updated_at DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []session.SessionInfo
|
||||
for rows.Next() {
|
||||
var info session.SessionInfo
|
||||
var updatedAt string
|
||||
if err := rows.Scan(&info.SessionID, &info.Version, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan session info: %w", err)
|
||||
}
|
||||
info.UpdatedAt, _ = time.Parse(timeFormat, updatedAt)
|
||||
sessions = append(sessions, info)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteSession removes all versions of a session. Returns the number of deleted rows.
|
||||
func (r *StateRepo) DeleteSession(ctx context.Context, sessionID string) (int, error) {
|
||||
now := time.Now().Format(timeFormat)
|
||||
|
||||
result, err := r.db.Exec(`DELETE FROM states WHERE session_id = ?`, sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("delete session: %w", err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
|
||||
// Audit log.
|
||||
_, _ = r.db.Exec(`INSERT INTO audit_log (session_id, action, timestamp, details)
|
||||
VALUES (?, 'delete', ?, ?)`,
|
||||
sessionID, now, fmt.Sprintf("deleted %d versions of session %s", n, sessionID),
|
||||
)
|
||||
|
||||
return int(n), nil
|
||||
}
|
||||
|
||||
// GetAuditLog returns the audit log for a session.
|
||||
func (r *StateRepo) GetAuditLog(ctx context.Context, sessionID string, limit int) ([]session.AuditEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := r.db.Query(`SELECT session_id, action, version, timestamp, details
|
||||
FROM audit_log WHERE session_id = ? ORDER BY id DESC LIMIT ?`,
|
||||
sessionID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get audit log: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []session.AuditEntry
|
||||
for rows.Next() {
|
||||
var e session.AuditEntry
|
||||
var version sql.NullInt64
|
||||
if err := rows.Scan(&e.SessionID, &e.Action, &version, &e.Timestamp, &e.Details); err != nil {
|
||||
return nil, fmt.Errorf("scan audit entry: %w", err)
|
||||
}
|
||||
if version.Valid {
|
||||
e.Version = int(version.Int64)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// Ensure StateRepo implements session.StateStore.
|
||||
var _ session.StateStore = (*StateRepo)(nil)
|
||||
163
internal/infrastructure/sqlite/state_repo_test.go
Normal file
163
internal/infrastructure/sqlite/state_repo_test.go
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/session"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestStateRepo(t *testing.T) *StateRepo {
|
||||
t.Helper()
|
||||
db, err := OpenMemory()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { db.Close() })
|
||||
|
||||
repo, err := NewStateRepo(db)
|
||||
require.NoError(t, err)
|
||||
return repo
|
||||
}
|
||||
|
||||
func TestStateRepo_Save_Load(t *testing.T) {
|
||||
repo := newTestStateRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
csv := session.NewCognitiveStateVector("test-session")
|
||||
csv.SetGoal("Build GoMCP", 0.3)
|
||||
csv.AddHypothesis("SQLite is fast enough")
|
||||
csv.AddDecision("Use mcp-go", "mature lib", []string{"custom"})
|
||||
csv.AddFact("Go 1.25", "requirement", 1.0)
|
||||
|
||||
checksum := csv.Checksum()
|
||||
err := repo.Save(ctx, csv, checksum)
|
||||
require.NoError(t, err)
|
||||
|
||||
loaded, gotChecksum, err := repo.Load(ctx, "test-session", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loaded)
|
||||
|
||||
assert.Equal(t, csv.SessionID, loaded.SessionID)
|
||||
assert.Equal(t, csv.Version, loaded.Version)
|
||||
assert.Equal(t, checksum, gotChecksum)
|
||||
require.NotNil(t, loaded.PrimaryGoal)
|
||||
assert.Equal(t, "Build GoMCP", loaded.PrimaryGoal.Description)
|
||||
assert.Len(t, loaded.Hypotheses, 1)
|
||||
assert.Len(t, loaded.Decisions, 1)
|
||||
assert.Len(t, loaded.Facts, 1)
|
||||
}
|
||||
|
||||
func TestStateRepo_Save_Versioning(t *testing.T) {
|
||||
repo := newTestStateRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
csv := session.NewCognitiveStateVector("s1")
|
||||
csv.SetGoal("v1", 0.1)
|
||||
require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
|
||||
|
||||
csv.BumpVersion()
|
||||
csv.SetGoal("v2", 0.5)
|
||||
require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
|
||||
|
||||
// Load latest
|
||||
loaded, _, err := repo.Load(ctx, "s1", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, loaded.Version)
|
||||
assert.Equal(t, "v2", loaded.PrimaryGoal.Description)
|
||||
|
||||
// Load specific version
|
||||
v := 1
|
||||
loaded, _, err = repo.Load(ctx, "s1", &v)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, loaded.Version)
|
||||
assert.Equal(t, "v1", loaded.PrimaryGoal.Description)
|
||||
}
|
||||
|
||||
func TestStateRepo_Load_NotFound(t *testing.T) {
|
||||
repo := newTestStateRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
loaded, _, err := repo.Load(ctx, "nonexistent", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, loaded)
|
||||
}
|
||||
|
||||
func TestStateRepo_ListSessions(t *testing.T) {
|
||||
repo := newTestStateRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
s1 := session.NewCognitiveStateVector("session-1")
|
||||
s2 := session.NewCognitiveStateVector("session-2")
|
||||
require.NoError(t, repo.Save(ctx, s1, s1.Checksum()))
|
||||
require.NoError(t, repo.Save(ctx, s2, s2.Checksum()))
|
||||
|
||||
sessions, err := repo.ListSessions(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, sessions, 2)
|
||||
}
|
||||
|
||||
func TestStateRepo_DeleteSession(t *testing.T) {
|
||||
repo := newTestStateRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
csv := session.NewCognitiveStateVector("to-delete")
|
||||
require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
|
||||
|
||||
csv.BumpVersion()
|
||||
require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
|
||||
|
||||
count, err := repo.DeleteSession(ctx, "to-delete")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
loaded, _, err := repo.Load(ctx, "to-delete", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, loaded)
|
||||
}
|
||||
|
||||
func TestStateRepo_AuditLog(t *testing.T) {
|
||||
repo := newTestStateRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
csv := session.NewCognitiveStateVector("audited")
|
||||
require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
|
||||
|
||||
csv.BumpVersion()
|
||||
require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
|
||||
|
||||
log, err := repo.GetAuditLog(ctx, "audited", 10)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(log), 2)
|
||||
assert.Equal(t, "audited", log[0].SessionID)
|
||||
}
|
||||
|
||||
func TestStateRepo_ComplexState_RoundTrip(t *testing.T) {
|
||||
repo := newTestStateRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
csv := session.NewCognitiveStateVector("complex")
|
||||
csv.SetGoal("Build full system", 0.7)
|
||||
csv.AddHypothesis("H1")
|
||||
csv.AddHypothesis("H2")
|
||||
csv.AddDecision("D1", "R1", []string{"A1", "A2"})
|
||||
csv.AddDecision("D2", "R2", nil)
|
||||
csv.AddFact("F1", "requirement", 0.9)
|
||||
csv.AddFact("F2", "decision", 1.0)
|
||||
csv.AddFact("F3", "context", 0.5)
|
||||
csv.OpenQuestions = []string{"Q1", "Q2", "Q3"}
|
||||
csv.ConfidenceMap["area1"] = 0.8
|
||||
csv.ConfidenceMap["area2"] = 0.3
|
||||
|
||||
require.NoError(t, repo.Save(ctx, csv, csv.Checksum()))
|
||||
|
||||
loaded, _, err := repo.Load(ctx, "complex", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Build full system", loaded.PrimaryGoal.Description)
|
||||
assert.Len(t, loaded.Hypotheses, 2)
|
||||
assert.Len(t, loaded.Decisions, 2)
|
||||
assert.Len(t, loaded.Facts, 3)
|
||||
assert.Len(t, loaded.OpenQuestions, 3)
|
||||
assert.InDelta(t, 0.8, loaded.ConfidenceMap["area1"], 0.001)
|
||||
}
|
||||
133
internal/infrastructure/sqlite/synapse_repo.go
Normal file
133
internal/infrastructure/sqlite/synapse_repo.go
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/synapse"
|
||||
)
|
||||
|
||||
// SynapseRepo implements synapse.SynapseStore using SQLite.
|
||||
type SynapseRepo struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewSynapseRepo creates a SynapseRepo (table created by FactRepo migration v3.3).
|
||||
func NewSynapseRepo(db *DB) *SynapseRepo {
|
||||
return &SynapseRepo{db: db}
|
||||
}
|
||||
|
||||
// Create inserts a new PENDING synapse.
|
||||
func (r *SynapseRepo) Create(ctx context.Context, factIDA, factIDB string, confidence float64) (int64, error) {
|
||||
result, err := r.db.Exec(
|
||||
`INSERT INTO synapses (fact_id_a, fact_id_b, confidence, status, created_at)
|
||||
VALUES (?, ?, ?, 'PENDING', ?)`,
|
||||
factIDA, factIDB, confidence, time.Now().Format(timeFormat))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create synapse: %w", err)
|
||||
}
|
||||
id, _ := result.LastInsertId()
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ListPending returns synapses with status PENDING.
|
||||
func (r *SynapseRepo) ListPending(ctx context.Context, limit int) ([]*synapse.Synapse, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := r.db.Query(
|
||||
`SELECT id, fact_id_a, fact_id_b, confidence, status, created_at
|
||||
FROM synapses WHERE status = 'PENDING' ORDER BY confidence DESC LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list pending: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanSynapses(rows)
|
||||
}
|
||||
|
||||
// Accept transitions a synapse to VERIFIED.
|
||||
func (r *SynapseRepo) Accept(ctx context.Context, id int64) error {
|
||||
result, err := r.db.Exec(`UPDATE synapses SET status = 'VERIFIED' WHERE id = ? AND status = 'PENDING'`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accept synapse: %w", err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("synapse %d not found or not PENDING", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reject transitions a synapse to REJECTED.
|
||||
func (r *SynapseRepo) Reject(ctx context.Context, id int64) error {
|
||||
result, err := r.db.Exec(`UPDATE synapses SET status = 'REJECTED' WHERE id = ? AND status = 'PENDING'`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reject synapse: %w", err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("synapse %d not found or not PENDING", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListVerified returns all VERIFIED synapses.
|
||||
func (r *SynapseRepo) ListVerified(ctx context.Context) ([]*synapse.Synapse, error) {
|
||||
rows, err := r.db.Query(
|
||||
`SELECT id, fact_id_a, fact_id_b, confidence, status, created_at
|
||||
FROM synapses WHERE status = 'VERIFIED' ORDER BY confidence DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list verified: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanSynapses(rows)
|
||||
}
|
||||
|
||||
// Count returns synapse counts by status.
|
||||
func (r *SynapseRepo) Count(ctx context.Context) (pending, verified, rejected int, err error) {
|
||||
row := r.db.QueryRow(`SELECT
|
||||
COALESCE(SUM(CASE WHEN status='PENDING' THEN 1 ELSE 0 END), 0),
|
||||
COALESCE(SUM(CASE WHEN status='VERIFIED' THEN 1 ELSE 0 END), 0),
|
||||
COALESCE(SUM(CASE WHEN status='REJECTED' THEN 1 ELSE 0 END), 0)
|
||||
FROM synapses`)
|
||||
err = row.Scan(&pending, &verified, &rejected)
|
||||
if err != nil {
|
||||
return 0, 0, 0, fmt.Errorf("count synapses: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Exists checks if a synapse exists between two facts (either direction).
|
||||
func (r *SynapseRepo) Exists(ctx context.Context, factIDA, factIDB string) (bool, error) {
|
||||
var count int
|
||||
row := r.db.QueryRow(
|
||||
`SELECT COUNT(*) FROM synapses
|
||||
WHERE (fact_id_a = ? AND fact_id_b = ?) OR (fact_id_a = ? AND fact_id_b = ?)`,
|
||||
factIDA, factIDB, factIDB, factIDA)
|
||||
if err := row.Scan(&count); err != nil {
|
||||
return false, fmt.Errorf("exists synapse: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Ensure SynapseRepo implements synapse.SynapseStore.
|
||||
var _ synapse.SynapseStore = (*SynapseRepo)(nil)
|
||||
|
||||
func scanSynapses(rows interface {
|
||||
Next() bool
|
||||
Scan(...any) error
|
||||
}) ([]*synapse.Synapse, error) {
|
||||
var result []*synapse.Synapse
|
||||
for rows.Next() {
|
||||
var s synapse.Synapse
|
||||
var status, createdAt string
|
||||
if err := rows.Scan(&s.ID, &s.FactIDA, &s.FactIDB, &s.Confidence, &status, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("scan synapse: %w", err)
|
||||
}
|
||||
s.Status = synapse.Status(status)
|
||||
s.CreatedAt, _ = time.Parse(timeFormat, createdAt)
|
||||
result = append(result, &s)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
164
internal/infrastructure/sqlite/synapse_repo_test.go
Normal file
164
internal/infrastructure/sqlite/synapse_repo_test.go
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
package sqlite_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sentinel-community/gomcp/internal/domain/memory"
|
||||
"github.com/sentinel-community/gomcp/internal/infrastructure/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupSynapseTest(t *testing.T) (*sqlite.SynapseRepo, *sqlite.FactRepo) {
|
||||
t.Helper()
|
||||
db, err := sqlite.OpenMemory()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { db.Close() })
|
||||
|
||||
// FactRepo migration creates synapses table.
|
||||
factRepo, err := sqlite.NewFactRepo(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
synapseRepo := sqlite.NewSynapseRepo(db)
|
||||
return synapseRepo, factRepo
|
||||
}
|
||||
|
||||
func TestSynapseRepo_CreateAndListPending(t *testing.T) {
|
||||
repo, factRepo := setupSynapseTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two facts to link.
|
||||
f1 := memory.NewFact("Architecture overview", memory.LevelDomain, "arch", "")
|
||||
f2 := memory.NewFact("Security module design", memory.LevelDomain, "security", "")
|
||||
require.NoError(t, factRepo.Add(ctx, f1))
|
||||
require.NoError(t, factRepo.Add(ctx, f2))
|
||||
|
||||
// Create synapse.
|
||||
id, err := repo.Create(ctx, f1.ID, f2.ID, 0.92)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// List pending.
|
||||
pending, err := repo.ListPending(ctx, 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pending, 1)
|
||||
assert.Equal(t, f1.ID, pending[0].FactIDA)
|
||||
assert.Equal(t, f2.ID, pending[0].FactIDB)
|
||||
assert.InDelta(t, 0.92, pending[0].Confidence, 0.01)
|
||||
assert.Equal(t, "PENDING", string(pending[0].Status))
|
||||
}
|
||||
|
||||
func TestSynapseRepo_AcceptAndListVerified(t *testing.T) {
|
||||
repo, factRepo := setupSynapseTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("fact A", memory.LevelModule, "test", "")
|
||||
f2 := memory.NewFact("fact B", memory.LevelModule, "test", "")
|
||||
require.NoError(t, factRepo.Add(ctx, f1))
|
||||
require.NoError(t, factRepo.Add(ctx, f2))
|
||||
|
||||
id, err := repo.Create(ctx, f1.ID, f2.ID, 0.88)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Accept.
|
||||
require.NoError(t, repo.Accept(ctx, id))
|
||||
|
||||
// Should no longer be in pending.
|
||||
pending, _ := repo.ListPending(ctx, 10)
|
||||
assert.Empty(t, pending)
|
||||
|
||||
// Should be in verified.
|
||||
verified, err := repo.ListVerified(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, verified, 1)
|
||||
assert.Equal(t, "VERIFIED", string(verified[0].Status))
|
||||
}
|
||||
|
||||
func TestSynapseRepo_Reject(t *testing.T) {
|
||||
repo, factRepo := setupSynapseTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("fact X", memory.LevelProject, "", "")
|
||||
f2 := memory.NewFact("fact Y", memory.LevelProject, "", "")
|
||||
require.NoError(t, factRepo.Add(ctx, f1))
|
||||
require.NoError(t, factRepo.Add(ctx, f2))
|
||||
|
||||
id, err := repo.Create(ctx, f1.ID, f2.ID, 0.50)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, repo.Reject(ctx, id))
|
||||
|
||||
pending, _ := repo.ListPending(ctx, 10)
|
||||
assert.Empty(t, pending)
|
||||
|
||||
verified, _ := repo.ListVerified(ctx)
|
||||
assert.Empty(t, verified)
|
||||
}
|
||||
|
||||
func TestSynapseRepo_Count(t *testing.T) {
|
||||
repo, factRepo := setupSynapseTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("a", memory.LevelProject, "", "")
|
||||
f2 := memory.NewFact("b", memory.LevelProject, "", "")
|
||||
f3 := memory.NewFact("c", memory.LevelProject, "", "")
|
||||
require.NoError(t, factRepo.Add(ctx, f1))
|
||||
require.NoError(t, factRepo.Add(ctx, f2))
|
||||
require.NoError(t, factRepo.Add(ctx, f3))
|
||||
|
||||
id1, _ := repo.Create(ctx, f1.ID, f2.ID, 0.90)
|
||||
id2, _ := repo.Create(ctx, f1.ID, f3.ID, 0.85)
|
||||
_, _ = repo.Create(ctx, f2.ID, f3.ID, 0.40)
|
||||
|
||||
_ = repo.Accept(ctx, id1)
|
||||
_ = repo.Reject(ctx, id2)
|
||||
|
||||
pending, verified, rejected, err := repo.Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, pending)
|
||||
assert.Equal(t, 1, verified)
|
||||
assert.Equal(t, 1, rejected)
|
||||
}
|
||||
|
||||
func TestSynapseRepo_Exists(t *testing.T) {
|
||||
repo, factRepo := setupSynapseTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("p", memory.LevelProject, "", "")
|
||||
f2 := memory.NewFact("q", memory.LevelProject, "", "")
|
||||
require.NoError(t, factRepo.Add(ctx, f1))
|
||||
require.NoError(t, factRepo.Add(ctx, f2))
|
||||
|
||||
// No synapse yet.
|
||||
exists, err := repo.Exists(ctx, f1.ID, f2.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Create one.
|
||||
_, _ = repo.Create(ctx, f1.ID, f2.ID, 0.95)
|
||||
|
||||
// Should exist in both directions.
|
||||
exists, _ = repo.Exists(ctx, f1.ID, f2.ID)
|
||||
assert.True(t, exists)
|
||||
|
||||
exists, _ = repo.Exists(ctx, f2.ID, f1.ID)
|
||||
assert.True(t, exists, "bidirectional check should work")
|
||||
}
|
||||
|
||||
func TestSynapseRepo_AcceptNonPending_Fails(t *testing.T) {
|
||||
repo, factRepo := setupSynapseTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
f1 := memory.NewFact("m", memory.LevelProject, "", "")
|
||||
f2 := memory.NewFact("n", memory.LevelProject, "", "")
|
||||
require.NoError(t, factRepo.Add(ctx, f1))
|
||||
require.NoError(t, factRepo.Add(ctx, f2))
|
||||
|
||||
id, _ := repo.Create(ctx, f1.ID, f2.ID, 0.80)
|
||||
_ = repo.Accept(ctx, id)
|
||||
|
||||
// Trying to accept again should fail.
|
||||
err := repo.Accept(ctx, id)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue