mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-16 08:25:14 +02:00
Initial open-source release
This commit is contained in:
commit
1a42152e6f
1199 changed files with 257054 additions and 0 deletions
172
python/klo-daemon/src/klo_daemon/embeddings.py
Normal file
172
python/klo-daemon/src/klo_daemon/embeddings.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
"""Portable embedding compute helpers for KLO daemon."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"
|
||||
DEFAULT_EMBEDDING_DIMENSIONS = 384
|
||||
DEFAULT_MAX_BATCH_SIZE = 100
|
||||
|
||||
|
||||
class EmbeddingProvider(Protocol):
|
||||
"""Provider interface for local embedding compute."""
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int: ...
|
||||
|
||||
@property
|
||||
def max_batch_size(self) -> int: ...
|
||||
|
||||
def encode(self, texts: list[str]) -> list[list[float]]: ...
|
||||
|
||||
|
||||
class ComputeEmbeddingRequest(BaseModel):
|
||||
"""Request schema for computing a single embedding."""
|
||||
|
||||
text: str = Field(..., description="Text to compute embedding for", min_length=1)
|
||||
|
||||
|
||||
class ComputeEmbeddingResponse(BaseModel):
|
||||
"""Response schema for single embedding computation."""
|
||||
|
||||
embedding: list[float] = Field(..., description="384-dimensional embedding vector")
|
||||
|
||||
|
||||
class ComputeEmbeddingBulkRequest(BaseModel):
|
||||
"""Request schema for computing multiple embeddings."""
|
||||
|
||||
texts: list[str] = Field(
|
||||
...,
|
||||
description="List of texts to compute embeddings for",
|
||||
min_length=1,
|
||||
max_length=DEFAULT_MAX_BATCH_SIZE,
|
||||
)
|
||||
|
||||
|
||||
class ComputeEmbeddingBulkResponse(BaseModel):
|
||||
"""Response schema for bulk embedding computation."""
|
||||
|
||||
embeddings: list[list[float]] = Field(
|
||||
...,
|
||||
description="List of 384-dimensional embedding vectors",
|
||||
)
|
||||
|
||||
|
||||
class SentenceTransformersEmbeddingProvider:
|
||||
"""Lazy sentence-transformers provider for local embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = DEFAULT_SENTENCE_TRANSFORMER_MODEL,
|
||||
model: SentenceTransformer | None = None,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self._model = model
|
||||
self._model_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sentence-transformers"
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return DEFAULT_EMBEDDING_DIMENSIONS
|
||||
|
||||
@property
|
||||
def max_batch_size(self) -> int:
|
||||
return DEFAULT_MAX_BATCH_SIZE
|
||||
|
||||
def _get_model(self) -> SentenceTransformer:
|
||||
if self._model is not None:
|
||||
return self._model
|
||||
|
||||
with self._model_lock:
|
||||
if self._model is None:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger.info("Loading SentenceTransformer model: %s", self.model_name)
|
||||
self._model = SentenceTransformer(self.model_name)
|
||||
logger.info("SentenceTransformer model loaded successfully")
|
||||
|
||||
return self._model
|
||||
|
||||
def encode(self, texts: list[str]) -> list[list[float]]:
|
||||
model = self._get_model()
|
||||
if len(texts) == 1:
|
||||
raw_single = model.encode(texts[0]).tolist()
|
||||
return [[float(value) for value in raw_single]]
|
||||
|
||||
raw_bulk = model.encode(texts).tolist()
|
||||
return [[float(value) for value in embedding] for embedding in raw_bulk]
|
||||
|
||||
|
||||
_default_provider: SentenceTransformersEmbeddingProvider | None = None
|
||||
_default_provider_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_default_embedding_provider() -> SentenceTransformersEmbeddingProvider:
|
||||
"""Return the process-wide default embedding provider."""
|
||||
|
||||
global _default_provider
|
||||
|
||||
if _default_provider is not None:
|
||||
return _default_provider
|
||||
|
||||
with _default_provider_lock:
|
||||
if _default_provider is None:
|
||||
_default_provider = SentenceTransformersEmbeddingProvider()
|
||||
|
||||
return _default_provider
|
||||
|
||||
|
||||
def _validate_texts(texts: list[str], max_batch_size: int) -> None:
|
||||
if not texts:
|
||||
raise ValueError("Texts array must not be empty")
|
||||
if len(texts) > max_batch_size:
|
||||
raise ValueError(f"Maximum {max_batch_size} texts allowed per batch")
|
||||
|
||||
empty_indices = [
|
||||
index for index, text in enumerate(texts) if not text or not text.strip()
|
||||
]
|
||||
if empty_indices:
|
||||
joined_indices = ", ".join(str(index) for index in empty_indices)
|
||||
raise ValueError(f"Empty texts found at indices: {joined_indices}")
|
||||
|
||||
|
||||
def compute_embedding_response(
|
||||
request: ComputeEmbeddingRequest,
|
||||
provider: EmbeddingProvider | None = None,
|
||||
) -> ComputeEmbeddingResponse:
|
||||
"""Compute one embedding from a request model."""
|
||||
|
||||
selected_provider = provider or get_default_embedding_provider()
|
||||
_validate_texts([request.text], selected_provider.max_batch_size)
|
||||
return ComputeEmbeddingResponse(
|
||||
embedding=selected_provider.encode([request.text])[0]
|
||||
)
|
||||
|
||||
|
||||
def compute_embedding_bulk_response(
|
||||
request: ComputeEmbeddingBulkRequest,
|
||||
provider: EmbeddingProvider | None = None,
|
||||
) -> ComputeEmbeddingBulkResponse:
|
||||
"""Compute multiple embeddings from a request model."""
|
||||
|
||||
selected_provider = provider or get_default_embedding_provider()
|
||||
_validate_texts(request.texts, selected_provider.max_batch_size)
|
||||
return ComputeEmbeddingBulkResponse(
|
||||
embeddings=selected_provider.encode(request.texts)
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue