mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-07 07:55:13 +02:00
172 lines
5.1 KiB
Python
172 lines
5.1 KiB
Python
"""Portable embedding compute helpers for KTX daemon."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
from typing import TYPE_CHECKING, Protocol
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
if TYPE_CHECKING:
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Defaults used by the request schemas and the sentence-transformers provider
# defined in this module.
DEFAULT_SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"  # model name used when none is supplied
DEFAULT_EMBEDDING_DIMENSIONS = 384  # value the built-in provider reports as `dimensions`
DEFAULT_MAX_BATCH_SIZE = 100  # upper bound on texts per batch
|
|
|
|
|
|
class EmbeddingProvider(Protocol):
    """Structural interface implemented by local embedding backends.

    Any object exposing these three read-only properties and an ``encode``
    method with a matching signature satisfies the protocol; no inheritance
    is required.
    """

    @property
    def name(self) -> str:
        """Short identifier of the provider (e.g. ``"sentence-transformers"``)."""
        ...

    @property
    def dimensions(self) -> int:
        """Length of each embedding vector the provider reports."""
        ...

    @property
    def max_batch_size(self) -> int:
        """Largest number of texts accepted in one batch."""
        ...

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Turn *texts* into embedding vectors, each a list of floats."""
        ...
|
|
|
|
|
|
class ComputeEmbeddingRequest(BaseModel):
    """Payload carrying the one text whose embedding should be computed."""

    # Required field; `min_length=1` makes the model reject an empty string.
    text: str = Field(
        ...,
        min_length=1,
        description="Text to compute embedding for",
    )
|
|
|
|
|
|
class ComputeEmbeddingResponse(BaseModel):
    """Result payload holding the embedding of a single text."""

    # NOTE(review): the description hard-codes "384", which matches
    # DEFAULT_EMBEDDING_DIMENSIONS; nothing here enforces the vector length.
    embedding: list[float] = Field(
        ...,
        description="384-dimensional embedding vector",
    )
|
|
|
|
|
|
class ComputeEmbeddingBulkRequest(BaseModel):
    """Payload carrying several texts to embed in one call."""

    # Required field. The length bounds are intended to limit the list to
    # 1..DEFAULT_MAX_BATCH_SIZE entries (list-length semantics as in
    # pydantic v2 -- TODO confirm against the pinned pydantic version).
    texts: list[str] = Field(
        ...,
        min_length=1,
        max_length=DEFAULT_MAX_BATCH_SIZE,
        description="List of texts to compute embeddings for",
    )
|
|
|
|
|
|
class ComputeEmbeddingBulkResponse(BaseModel):
    """Result payload holding the embeddings of a batch of texts."""

    # NOTE(review): the description hard-codes "384", which matches
    # DEFAULT_EMBEDDING_DIMENSIONS; nothing here enforces the vector length.
    embeddings: list[list[float]] = Field(
        ..., description="List of 384-dimensional embedding vectors"
    )
|
|
|
|
|
|
class SentenceTransformersEmbeddingProvider:
    """Lazy sentence-transformers provider for local embeddings.

    The ``sentence_transformers`` package is imported and the model is
    constructed on first use, not when the provider is created. A ready-made
    model can be injected instead, in which case nothing is loaded.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_SENTENCE_TRANSFORMER_MODEL,
        model: SentenceTransformer | None = None,
    ) -> None:
        """Record the configuration without loading anything.

        Args:
            model_name: Name handed to ``SentenceTransformer`` when the model
                has to be loaded on demand.
            model: Optional pre-built model; when given it is used as-is.
        """
        self.model_name = model_name
        self._model = model
        # Guards the one-time model load in `_get_model`.
        self._model_lock = threading.Lock()

    @property
    def name(self) -> str:
        """Fixed provider identifier."""
        return "sentence-transformers"

    @property
    def dimensions(self) -> int:
        """Embedding size reported by this provider.

        This is always ``DEFAULT_EMBEDDING_DIMENSIONS``; it is not derived
        from ``model_name`` or from the loaded model.
        """
        return DEFAULT_EMBEDDING_DIMENSIONS

    @property
    def max_batch_size(self) -> int:
        """Maximum number of texts per batch (``DEFAULT_MAX_BATCH_SIZE``)."""
        return DEFAULT_MAX_BATCH_SIZE

    def _get_model(self) -> SentenceTransformer:
        """Return the model, loading it on first call.

        The unlocked check keeps the common path lock-free; the second check
        under the lock ensures only one thread performs the load.
        """
        if self._model is not None:
            return self._model

        with self._model_lock:
            if self._model is None:
                # Imported here so the dependency is only needed once a model
                # actually has to be loaded.
                from sentence_transformers import SentenceTransformer

                logger.info("Loading SentenceTransformer model: %s", self.model_name)
                self._model = SentenceTransformer(self.model_name)
                logger.info("SentenceTransformer model loaded successfully")

        return self._model

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Encode *texts* into plain Python float vectors.

        Args:
            texts: Texts to embed.

        Returns:
            One list of floats per input text, in input order. An empty
            input yields an empty list.
        """
        # Nothing to embed: avoid loading the model (a potentially expensive
        # step) and avoid handing an empty batch to it.
        if not texts:
            return []

        model = self._get_model()
        if len(texts) == 1:
            # A lone text is passed as a bare string, so the model's result is
            # a single vector that gets wrapped into a one-element list.
            raw_single = model.encode(texts[0]).tolist()
            return [[float(value) for value in raw_single]]

        raw_bulk = model.encode(texts).tolist()
        return [[float(value) for value in embedding] for embedding in raw_bulk]
|
|
|
|
|
|
# Process-wide provider instance, created on first request, plus the lock
# that guards its creation.
_default_provider: SentenceTransformersEmbeddingProvider | None = None
_default_provider_lock = threading.Lock()


def get_default_embedding_provider() -> SentenceTransformersEmbeddingProvider:
    """Return the shared default provider, creating it on first use.

    The cached instance is checked once without the lock and again while
    holding it, so concurrent first callers end up with the same object.
    """
    global _default_provider

    cached = _default_provider
    if cached is None:
        with _default_provider_lock:
            # Re-read under the lock: another thread may have created the
            # instance while this one was waiting.
            cached = _default_provider
            if cached is None:
                cached = SentenceTransformersEmbeddingProvider()
                _default_provider = cached
    return cached
|
|
|
|
|
|
def _validate_texts(texts: list[str], max_batch_size: int) -> None:
|
|
if not texts:
|
|
raise ValueError("Texts array must not be empty")
|
|
if len(texts) > max_batch_size:
|
|
raise ValueError(f"Maximum {max_batch_size} texts allowed per batch")
|
|
|
|
empty_indices = [
|
|
index for index, text in enumerate(texts) if not text or not text.strip()
|
|
]
|
|
if empty_indices:
|
|
joined_indices = ", ".join(str(index) for index in empty_indices)
|
|
raise ValueError(f"Empty texts found at indices: {joined_indices}")
|
|
|
|
|
|
def compute_embedding_response(
    request: ComputeEmbeddingRequest,
    provider: EmbeddingProvider | None = None,
) -> ComputeEmbeddingResponse:
    """Compute one embedding from a request model.

    Args:
        request: Request whose ``text`` is embedded.
        provider: Backend to use. When ``None``, the process-wide default
            provider is used.

    Returns:
        A response wrapping the first (and only) vector the provider returns
        for the one-element batch.

    Raises:
        ValueError: Propagated from ``_validate_texts`` when the text is
            blank or the provider's ``max_batch_size`` is below one.
    """
    # Compare against None explicitly: a supplied provider that happens to be
    # falsy (e.g. one defining __len__ or __bool__) must not be silently
    # replaced by the default.
    selected_provider = (
        provider if provider is not None else get_default_embedding_provider()
    )
    _validate_texts([request.text], selected_provider.max_batch_size)
    return ComputeEmbeddingResponse(
        embedding=selected_provider.encode([request.text])[0]
    )
|
|
|
|
|
|
def compute_embedding_bulk_response(
    request: ComputeEmbeddingBulkRequest,
    provider: EmbeddingProvider | None = None,
) -> ComputeEmbeddingBulkResponse:
    """Compute multiple embeddings from a request model.

    Args:
        request: Request whose ``texts`` are embedded as one batch.
        provider: Backend to use. When ``None``, the process-wide default
            provider is used.

    Returns:
        A response wrapping whatever the provider's ``encode`` returns for
        the batch.

    Raises:
        ValueError: Propagated from ``_validate_texts`` when the batch is
            empty, exceeds the provider's ``max_batch_size``, or contains
            blank entries.
    """
    # Compare against None explicitly: a supplied provider that happens to be
    # falsy (e.g. one defining __len__ or __bool__) must not be silently
    # replaced by the default.
    selected_provider = (
        provider if provider is not None else get_default_embedding_provider()
    )
    _validate_texts(request.texts, selected_provider.max_batch_size)
    return ComputeEmbeddingBulkResponse(
        embeddings=selected_provider.encode(request.texts)
    )
|