gomcp/internal/infrastructure/onnx/embedder.go
DmitrL-dev 694e32be26 refactor: rename identity to syntrex, add root orchestration and CI/CD
- Rename Go module: sentinel-community/gomcp -> syntrex/gomcp (50+ files)
- Rename npm package: sentinel-dashboard -> syntrex-dashboard
- Update Cargo.toml repository URL to syntrex/syntrex
- Update all doc references from DmitrL-dev/AISecurity to syntrex
- Add root Makefile (build-all, test-all, lint-all, clean-all)
- Add MIT LICENSE
- Add .editorconfig (Go/Rust/TS/C cross-language)
- Add .github/workflows/ci.yml (Go + Rust + Dashboard)
- Add dashboard next.config.ts and .env.example
- Clean ARCHITECTURE.md: remove brain/immune/strike/micro-swarm, fix 61->67 engines
2026-03-11 15:30:49 +10:00

228 lines
5.6 KiB
Go

//go:build onnx
package onnx
import (
	"context"
	"errors"
	"fmt"
	"log"
	"math"
	"sync"

	ort "github.com/yalue/onnxruntime_go"

	"github.com/syntrex/gomcp/internal/domain/vectorstore"
)
// Embedder implements vectorstore.Embedder using ONNX Runtime.
// Runs paraphrase-multilingual-MiniLM-L12-v2 inference locally.
//
// All methods that touch the session (Embed, Close) serialize through mu,
// so a single Embedder is safe for concurrent use but runs one inference
// at a time.
type Embedder struct {
mu sync.Mutex // guards session access in Embed and Close
session *ort.DynamicAdvancedSession // loaded ONNX inference session
tokenizer *Tokenizer // text -> input_ids/attention_mask/token_type_ids
dimension int // output embedding width (384 for MiniLM-L12-v2)
modelName string // short model identifier reported by Name()
}
// Config holds ONNX embedder configuration.
//
// Explicit ModelPath/VocabPath overrides take precedence; any path left
// empty is filled in by DiscoverModel(RlmDir) during NewEmbedder.
type Config struct {
RlmDir string // Path to .rlm directory (for model discovery)
ModelPath string // Override: direct path to .onnx model
VocabPath string // Override: direct path to vocab.txt
MaxSeqLen int // Max sequence length (default: 128)
}
// NewEmbedder creates an ONNX-based embedder.
//
// Paths missing from cfg are resolved via DiscoverModel(cfg.RlmDir); the
// ONNX Runtime environment is initialized and a MiniLM inference session
// is opened. On any failure after environment initialization the
// environment is torn down again so nothing leaks.
//
// Returns an error if the model or runtime cannot be loaded.
// Caller should fall back to FTS5Embedder on error.
func NewEmbedder(cfg Config) (*Embedder, error) {
	// Discover model paths if not explicitly provided.
	paths := &ModelPaths{
		ModelPath: cfg.ModelPath,
		VocabPath: cfg.VocabPath,
	}
	if paths.ModelPath == "" || paths.VocabPath == "" {
		discovered, err := DiscoverModel(cfg.RlmDir)
		if err != nil {
			return nil, fmt.Errorf("onnx discovery: %w", err)
		}
		if paths.ModelPath == "" {
			paths.ModelPath = discovered.ModelPath
		}
		if paths.VocabPath == "" {
			paths.VocabPath = discovered.VocabPath
		}
		// Set runtime path for ONNX Runtime initialization.
		if discovered.RuntimePath != "" {
			ort.SetSharedLibraryPath(discovered.RuntimePath)
		}
	}
	// Initialize ONNX Runtime. From here on, every error path must
	// destroy the environment again or it leaks.
	if err := ort.InitializeEnvironment(); err != nil {
		return nil, fmt.Errorf("onnx init: %w", err)
	}
	// Load tokenizer.
	maxSeqLen := cfg.MaxSeqLen
	if maxSeqLen <= 0 {
		maxSeqLen = 128
	}
	tokenizer, err := NewTokenizer(TokenizerConfig{
		VocabPath: paths.VocabPath,
		MaxLength: maxSeqLen,
	})
	if err != nil {
		// Best-effort teardown; the construction error is what matters.
		_ = ort.DestroyEnvironment()
		return nil, fmt.Errorf("onnx tokenizer: %w", err)
	}
	log.Printf("ONNX tokenizer loaded: %d tokens from %s", tokenizer.VocabSize(), paths.VocabPath)
	// Create ONNX session.
	// MiniLM inputs: input_ids [1, seq_len], attention_mask [1, seq_len], token_type_ids [1, seq_len]
	// MiniLM output: last_hidden_state [1, seq_len, 384] → mean pool → [384]
	inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
	outputNames := []string{"last_hidden_state"}
	session, err := ort.NewDynamicAdvancedSession(
		paths.ModelPath,
		inputNames,
		outputNames,
		nil, // default session options
	)
	if err != nil {
		// Best-effort teardown; the construction error is what matters.
		_ = ort.DestroyEnvironment()
		return nil, fmt.Errorf("onnx session: %w", err)
	}
	log.Printf("ONNX model loaded: %s (seq_len=%d)", paths.ModelPath, maxSeqLen)
	return &Embedder{
		session:   session,
		tokenizer: tokenizer,
		dimension: 384, // MiniLM-L12-v2.
		modelName: "MiniLM-L12-v2",
	}, nil
}
// Embed computes a 384-dim embedding via ONNX inference.
//
// The context is checked for cancellation before inference begins; a
// cancelled context returns ctx.Err() without running the model. Inference
// itself is serialized by e.mu, so concurrent callers queue up.
func (e *Embedder) Embed(ctx context.Context, text string) ([]float64, error) {
	// Honor caller cancellation/deadline before taking the lock and
	// committing to an (uncancellable) inference run.
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	// Tokenize.
	encoded := e.tokenizer.Encode(text)
	seqLen := int64(len(encoded.InputIDs))
	shape := ort.Shape{1, seqLen}
	// Create input tensors.
	inputIDsTensor, err := ort.NewTensor(shape, encoded.InputIDs)
	if err != nil {
		return nil, fmt.Errorf("create input_ids tensor: %w", err)
	}
	defer inputIDsTensor.Destroy()
	attMaskTensor, err := ort.NewTensor(shape, encoded.AttentionMask)
	if err != nil {
		return nil, fmt.Errorf("create attention_mask tensor: %w", err)
	}
	defer attMaskTensor.Destroy()
	tokenTypeTensor, err := ort.NewTensor(shape, encoded.TokenTypeIDs)
	if err != nil {
		return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
	}
	defer tokenTypeTensor.Destroy()
	// Create output tensor placeholder: [1, seq_len, 384].
	outputShape := ort.Shape{1, seqLen, int64(e.dimension)}
	outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
	if err != nil {
		return nil, fmt.Errorf("create output tensor: %w", err)
	}
	defer outputTensor.Destroy()
	// Run inference.
	err = e.session.Run(
		[]ort.ArbitraryTensor{inputIDsTensor, attMaskTensor, tokenTypeTensor},
		[]ort.ArbitraryTensor{outputTensor},
	)
	if err != nil {
		return nil, fmt.Errorf("onnx inference: %w", err)
	}
	// Mean pooling over non-padded tokens.
	rawOutput := outputTensor.GetData()
	embedding := meanPool(rawOutput, encoded.AttentionMask, int(seqLen), e.dimension)
	// L2 normalize.
	l2Normalize(embedding)
	return embedding, nil
}
// Dimension returns the width of vectors produced by Embed
// (384 for MiniLM-L12-v2).
func (e *Embedder) Dimension() int {
return e.dimension
}
// Name returns the embedder identifier, e.g. "onnx:MiniLM-L12-v2".
func (e *Embedder) Name() string {
	// Simple two-part concatenation; no formatting machinery needed.
	return "onnx:" + e.modelName
}
// Mode returns FULL — ONNX provides neural embeddings (as opposed to
// lexical fallback modes).
func (e *Embedder) Mode() vectorstore.OracleMode {
return vectorstore.OracleModeFull
}
// Close releases ONNX resources: the inference session (if any) and the
// ONNX Runtime environment. Both teardown errors are reported via
// errors.Join rather than silently dropping the session error. The
// session pointer is cleared so a repeated Close does not double-destroy.
func (e *Embedder) Close() error {
	e.mu.Lock()
	defer e.mu.Unlock()
	var errs []error
	if e.session != nil {
		if err := e.session.Destroy(); err != nil {
			errs = append(errs, fmt.Errorf("onnx session destroy: %w", err))
		}
		e.session = nil
	}
	if err := ort.DestroyEnvironment(); err != nil {
		errs = append(errs, fmt.Errorf("onnx env destroy: %w", err))
	}
	return errors.Join(errs...)
}
// meanPool averages the per-token hidden states into a single vector of
// length dim, counting only positions whose attention_mask entry is 1
// (i.e. skipping padding). An all-padding input yields the zero vector.
func meanPool(hiddenStates []float32, attentionMask []int64, seqLen, dim int) []float64 {
	pooled := make([]float64, dim)
	active := 0
	for pos := 0; pos < seqLen; pos++ {
		if attentionMask[pos] == 0 {
			continue // padding token: excluded from the mean
		}
		active++
		row := hiddenStates[pos*dim : (pos+1)*dim]
		for d, v := range row {
			pooled[d] += float64(v)
		}
	}
	if active == 0 {
		return pooled
	}
	for d := range pooled {
		pooled[d] /= float64(active)
	}
	return pooled
}
// l2Normalize scales vec in-place so its Euclidean norm is 1.
// A zero vector is left untouched (no division by zero).
func l2Normalize(vec []float64) {
	var sumSq float64
	for _, c := range vec {
		sumSq += c * c
	}
	length := math.Sqrt(sumSq)
	// Guard written as !(length > 0) so non-positive and NaN norms both
	// skip scaling, exactly like the original positive-norm check.
	if !(length > 0) {
		return
	}
	for i := range vec {
		vec[i] /= length
	}
}