mirror of
https://github.com/syntrex-lab/gomcp.git
synced 2026-04-27 05:16:22 +02:00
- Rename Go module: sentinel-community/gomcp -> syntrex/gomcp (50+ files) - Rename npm package: sentinel-dashboard -> syntrex-dashboard - Update Cargo.toml repository URL to syntrex/syntrex - Update all doc references from DmitrL-dev/AISecurity to syntrex - Add root Makefile (build-all, test-all, lint-all, clean-all) - Add MIT LICENSE - Add .editorconfig (Go/Rust/TS/C cross-language) - Add .github/workflows/ci.yml (Go + Rust + Dashboard) - Add dashboard next.config.ts and .env.example - Clean ARCHITECTURE.md: remove brain/immune/strike/micro-swarm, fix 61->67 engines
228 lines
5.6 KiB
Go
228 lines
5.6 KiB
Go
//go:build onnx
|
|
|
|
package onnx
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"sync"
|
|
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
|
|
"github.com/syntrex/gomcp/internal/domain/vectorstore"
|
|
)
|
|
|
|
// Embedder implements vectorstore.Embedder using ONNX Runtime.
|
|
// Runs paraphrase-multilingual-MiniLM-L12-v2 inference locally.
|
|
type Embedder struct {
|
|
mu sync.Mutex
|
|
session *ort.DynamicAdvancedSession
|
|
tokenizer *Tokenizer
|
|
dimension int
|
|
modelName string
|
|
}
|
|
|
|
// Config holds ONNX embedder configuration.
|
|
type Config struct {
|
|
RlmDir string // Path to .rlm directory (for model discovery)
|
|
ModelPath string // Override: direct path to .onnx model
|
|
VocabPath string // Override: direct path to vocab.txt
|
|
MaxSeqLen int // Max sequence length (default: 128)
|
|
}
|
|
|
|
// NewEmbedder creates an ONNX-based embedder.
|
|
// Returns an error if the model or runtime cannot be loaded.
|
|
// Caller should fall back to FTS5Embedder on error.
|
|
func NewEmbedder(cfg Config) (*Embedder, error) {
|
|
// Discover model paths if not explicitly provided.
|
|
paths := &ModelPaths{
|
|
ModelPath: cfg.ModelPath,
|
|
VocabPath: cfg.VocabPath,
|
|
}
|
|
|
|
if paths.ModelPath == "" || paths.VocabPath == "" {
|
|
discovered, err := DiscoverModel(cfg.RlmDir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("onnx discovery: %w", err)
|
|
}
|
|
if paths.ModelPath == "" {
|
|
paths.ModelPath = discovered.ModelPath
|
|
}
|
|
if paths.VocabPath == "" {
|
|
paths.VocabPath = discovered.VocabPath
|
|
}
|
|
// Set runtime path for ONNX Runtime initialization.
|
|
if discovered.RuntimePath != "" {
|
|
ort.SetSharedLibraryPath(discovered.RuntimePath)
|
|
}
|
|
}
|
|
|
|
// Initialize ONNX Runtime.
|
|
if err := ort.InitializeEnvironment(); err != nil {
|
|
return nil, fmt.Errorf("onnx init: %w", err)
|
|
}
|
|
|
|
// Load tokenizer.
|
|
maxSeqLen := cfg.MaxSeqLen
|
|
if maxSeqLen <= 0 {
|
|
maxSeqLen = 128
|
|
}
|
|
|
|
tokenizer, err := NewTokenizer(TokenizerConfig{
|
|
VocabPath: paths.VocabPath,
|
|
MaxLength: maxSeqLen,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("onnx tokenizer: %w", err)
|
|
}
|
|
|
|
log.Printf("ONNX tokenizer loaded: %d tokens from %s", tokenizer.VocabSize(), paths.VocabPath)
|
|
|
|
// Create ONNX session.
|
|
// MiniLM inputs: input_ids [1, seq_len], attention_mask [1, seq_len], token_type_ids [1, seq_len]
|
|
// MiniLM output: last_hidden_state [1, seq_len, 384] → mean pool → [384]
|
|
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
|
|
outputNames := []string{"last_hidden_state"}
|
|
|
|
session, err := ort.NewDynamicAdvancedSession(
|
|
paths.ModelPath,
|
|
inputNames,
|
|
outputNames,
|
|
nil, // default session options
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("onnx session: %w", err)
|
|
}
|
|
|
|
log.Printf("ONNX model loaded: %s (seq_len=%d)", paths.ModelPath, maxSeqLen)
|
|
|
|
return &Embedder{
|
|
session: session,
|
|
tokenizer: tokenizer,
|
|
dimension: 384, // MiniLM-L12-v2.
|
|
modelName: "MiniLM-L12-v2",
|
|
}, nil
|
|
}
|
|
|
|
// Embed computes a 384-dim embedding via ONNX inference.
|
|
func (e *Embedder) Embed(ctx context.Context, text string) ([]float64, error) {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
|
|
// Tokenize.
|
|
encoded := e.tokenizer.Encode(text)
|
|
|
|
seqLen := int64(len(encoded.InputIDs))
|
|
shape := ort.Shape{1, seqLen}
|
|
|
|
// Create input tensors.
|
|
inputIDsTensor, err := ort.NewTensor(shape, encoded.InputIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create input_ids tensor: %w", err)
|
|
}
|
|
defer inputIDsTensor.Destroy()
|
|
|
|
attMaskTensor, err := ort.NewTensor(shape, encoded.AttentionMask)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create attention_mask tensor: %w", err)
|
|
}
|
|
defer attMaskTensor.Destroy()
|
|
|
|
tokenTypeTensor, err := ort.NewTensor(shape, encoded.TokenTypeIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
|
|
}
|
|
defer tokenTypeTensor.Destroy()
|
|
|
|
// Create output tensor placeholder.
|
|
outputShape := ort.Shape{1, seqLen, int64(e.dimension)}
|
|
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create output tensor: %w", err)
|
|
}
|
|
defer outputTensor.Destroy()
|
|
|
|
// Run inference.
|
|
err = e.session.Run(
|
|
[]ort.ArbitraryTensor{inputIDsTensor, attMaskTensor, tokenTypeTensor},
|
|
[]ort.ArbitraryTensor{outputTensor},
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("onnx inference: %w", err)
|
|
}
|
|
|
|
// Mean pooling over non-padded tokens.
|
|
rawOutput := outputTensor.GetData()
|
|
embedding := meanPool(rawOutput, encoded.AttentionMask, int(seqLen), e.dimension)
|
|
|
|
// L2 normalize.
|
|
l2Normalize(embedding)
|
|
|
|
return embedding, nil
|
|
}
|
|
|
|
// Dimension returns 384 (MiniLM-L12-v2).
|
|
func (e *Embedder) Dimension() int {
|
|
return e.dimension
|
|
}
|
|
|
|
// Name returns the embedder identifier.
|
|
func (e *Embedder) Name() string {
|
|
return fmt.Sprintf("onnx:%s", e.modelName)
|
|
}
|
|
|
|
// Mode returns FULL — ONNX provides neural embeddings.
|
|
func (e *Embedder) Mode() vectorstore.OracleMode {
|
|
return vectorstore.OracleModeFull
|
|
}
|
|
|
|
// Close releases ONNX resources.
|
|
func (e *Embedder) Close() error {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
if e.session != nil {
|
|
e.session.Destroy()
|
|
}
|
|
return ort.DestroyEnvironment()
|
|
}
|
|
|
|
// meanPool computes mean pooling over the hidden states,
|
|
// considering only non-padded positions (attention_mask=1).
|
|
func meanPool(hiddenStates []float32, attentionMask []int64, seqLen, dim int) []float64 {
|
|
result := make([]float64, dim)
|
|
var count float64
|
|
|
|
for i := 0; i < seqLen; i++ {
|
|
if attentionMask[i] == 0 {
|
|
continue
|
|
}
|
|
count++
|
|
offset := i * dim
|
|
for d := 0; d < dim; d++ {
|
|
result[d] += float64(hiddenStates[offset+d])
|
|
}
|
|
}
|
|
|
|
if count > 0 {
|
|
for d := range result {
|
|
result[d] /= count
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// l2Normalize normalizes a vector in-place to unit length.
|
|
func l2Normalize(vec []float64) {
|
|
var norm float64
|
|
for _, v := range vec {
|
|
norm += v * v
|
|
}
|
|
norm = math.Sqrt(norm)
|
|
if norm > 0 {
|
|
for i := range vec {
|
|
vec[i] /= norm
|
|
}
|
|
}
|
|
}
|