mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-26 04:46:22 +02:00
219 lines
6.5 KiB
Go
219 lines
6.5 KiB
Go
// Copyright 2026 Syntrex Lab. All rights reserved.
|
|
// Use of this source code is governed by an Apache-2.0 license
|
|
// that can be found in the LICENSE file.
|
|
|
|
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/syntrex-lab/gomcp/internal/domain/session"
|
|
)
|
|
|
|
// StateRepo implements session.StateStore using SQLite.
|
|
// Compatible with memory_bridge.db schema (states + audit_log).
|
|
// NOTE: The Python version uses AES-256-GCM encryption on the data blob.
|
|
// This Go implementation stores plaintext JSON for now — encryption
|
|
// can be layered on top via a decorator if needed.
|
|
type StateRepo struct {
|
|
db *DB
|
|
}
|
|
|
|
// NewStateRepo creates a StateRepo and ensures the schema exists.
|
|
func NewStateRepo(db *DB) (*StateRepo, error) {
|
|
repo := &StateRepo{db: db}
|
|
if err := repo.migrate(); err != nil {
|
|
return nil, fmt.Errorf("state repo migrate: %w", err)
|
|
}
|
|
return repo, nil
|
|
}
|
|
|
|
func (r *StateRepo) migrate() error {
|
|
stmts := []string{
|
|
`CREATE TABLE IF NOT EXISTS states (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
session_id TEXT NOT NULL,
|
|
version INTEGER NOT NULL,
|
|
timestamp TEXT NOT NULL,
|
|
checksum TEXT NOT NULL,
|
|
data BLOB NOT NULL,
|
|
nonce BLOB,
|
|
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
|
UNIQUE(session_id, version)
|
|
)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_session_id ON states(session_id)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_timestamp ON states(timestamp)`,
|
|
`CREATE TABLE IF NOT EXISTS audit_log (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
session_id TEXT NOT NULL,
|
|
action TEXT NOT NULL,
|
|
version INTEGER,
|
|
timestamp TEXT NOT NULL,
|
|
details TEXT,
|
|
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
|
)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_audit_session ON audit_log(session_id)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp)`,
|
|
}
|
|
for _, s := range stmts {
|
|
if _, err := r.db.Exec(s); err != nil {
|
|
return fmt.Errorf("exec migration: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Save persists a cognitive state vector snapshot.
|
|
func (r *StateRepo) Save(ctx context.Context, state *session.CognitiveStateVector, checksum string) error {
|
|
data, err := json.Marshal(state)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal state: %w", err)
|
|
}
|
|
|
|
// Verify checksum if provided, or compute one.
|
|
if checksum == "" {
|
|
h := sha256.Sum256(data)
|
|
checksum = hex.EncodeToString(h[:])
|
|
}
|
|
|
|
now := time.Now().Format(timeFormat)
|
|
|
|
// Determine action for audit log.
|
|
var action string
|
|
var existingCount int
|
|
row := r.db.QueryRow(`SELECT COUNT(*) FROM states WHERE session_id = ?`, state.SessionID)
|
|
if err := row.Scan(&existingCount); err != nil {
|
|
return fmt.Errorf("count existing: %w", err)
|
|
}
|
|
if existingCount == 0 {
|
|
action = "create"
|
|
} else {
|
|
action = "update"
|
|
}
|
|
|
|
_, err = r.db.Exec(`INSERT INTO states (session_id, version, timestamp, checksum, data)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
state.SessionID, state.Version, now, checksum, data,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("insert state: %w", err)
|
|
}
|
|
|
|
// Write audit log entry.
|
|
_, err = r.db.Exec(`INSERT INTO audit_log (session_id, action, version, timestamp, details)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
state.SessionID, action, state.Version, now,
|
|
fmt.Sprintf("%s session %s v%d", action, state.SessionID, state.Version),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("insert audit: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Load retrieves a cognitive state vector. If version is nil, loads the latest.
|
|
func (r *StateRepo) Load(ctx context.Context, sessionID string, version *int) (*session.CognitiveStateVector, string, error) {
|
|
var row *sql.Row
|
|
if version != nil {
|
|
row = r.db.QueryRow(`SELECT data, checksum FROM states WHERE session_id = ? AND version = ?`,
|
|
sessionID, *version)
|
|
} else {
|
|
row = r.db.QueryRow(`SELECT data, checksum FROM states WHERE session_id = ? ORDER BY version DESC LIMIT 1`,
|
|
sessionID)
|
|
}
|
|
|
|
var data []byte
|
|
var checksum string
|
|
if err := row.Scan(&data, &checksum); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, "", fmt.Errorf("session %s not found", sessionID)
|
|
}
|
|
return nil, "", fmt.Errorf("scan state: %w", err)
|
|
}
|
|
|
|
var state session.CognitiveStateVector
|
|
if err := json.Unmarshal(data, &state); err != nil {
|
|
return nil, "", fmt.Errorf("unmarshal state: %w", err)
|
|
}
|
|
|
|
return &state, checksum, nil
|
|
}
|
|
|
|
// ListSessions returns metadata about all persisted sessions.
|
|
func (r *StateRepo) ListSessions(ctx context.Context) ([]session.SessionInfo, error) {
|
|
rows, err := r.db.Query(`SELECT session_id, MAX(version) as version, MAX(timestamp) as updated_at
|
|
FROM states GROUP BY session_id ORDER BY updated_at DESC`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list sessions: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sessions []session.SessionInfo
|
|
for rows.Next() {
|
|
var info session.SessionInfo
|
|
var updatedAt string
|
|
if err := rows.Scan(&info.SessionID, &info.Version, &updatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan session info: %w", err)
|
|
}
|
|
info.UpdatedAt, _ = time.Parse(timeFormat, updatedAt)
|
|
sessions = append(sessions, info)
|
|
}
|
|
return sessions, rows.Err()
|
|
}
|
|
|
|
// DeleteSession removes all versions of a session. Returns the number of deleted rows.
|
|
func (r *StateRepo) DeleteSession(ctx context.Context, sessionID string) (int, error) {
|
|
now := time.Now().Format(timeFormat)
|
|
|
|
result, err := r.db.Exec(`DELETE FROM states WHERE session_id = ?`, sessionID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("delete session: %w", err)
|
|
}
|
|
n, _ := result.RowsAffected()
|
|
|
|
// Audit log.
|
|
_, _ = r.db.Exec(`INSERT INTO audit_log (session_id, action, timestamp, details)
|
|
VALUES (?, 'delete', ?, ?)`,
|
|
sessionID, now, fmt.Sprintf("deleted %d versions of session %s", n, sessionID),
|
|
)
|
|
|
|
return int(n), nil
|
|
}
|
|
|
|
// GetAuditLog returns the audit log for a session.
|
|
func (r *StateRepo) GetAuditLog(ctx context.Context, sessionID string, limit int) ([]session.AuditEntry, error) {
|
|
if limit <= 0 {
|
|
limit = 50
|
|
}
|
|
rows, err := r.db.Query(`SELECT session_id, action, version, timestamp, details
|
|
FROM audit_log WHERE session_id = ? ORDER BY id DESC LIMIT ?`,
|
|
sessionID, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get audit log: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var entries []session.AuditEntry
|
|
for rows.Next() {
|
|
var e session.AuditEntry
|
|
var version sql.NullInt64
|
|
if err := rows.Scan(&e.SessionID, &e.Action, &version, &e.Timestamp, &e.Details); err != nil {
|
|
return nil, fmt.Errorf("scan audit entry: %w", err)
|
|
}
|
|
if version.Valid {
|
|
e.Version = int(version.Int64)
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
// Ensure StateRepo implements session.StateStore.
|
|
var _ session.StateStore = (*StateRepo)(nil)
|