mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: implement thread-safe embedding access in document converters
- Added a reentrant lock to ensure thread-safe access to the tokenizer and embedding model, preventing runtime errors during concurrent operations.
- Updated the `truncate_for_embedding` and `embed_text` functions to use the lock, ensuring safe execution in multi-threaded environments.
- Enhanced the `embed_texts` function to maintain thread safety while processing multiple texts for embedding.
This commit is contained in:
parent
db6dd058dd
commit
683a4c17dd
1 changed file with 24 additions and 14 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent
|
||||
# access from multiple threads causes "RuntimeError: Already borrowed".
|
||||
# This reentrant lock serialises tokenizer + embedding model access so that
|
||||
# asyncio.to_thread calls from index_batch_parallel don't collide.
|
||||
_embedding_lock = threading.RLock()
|
||||
|
||||
|
||||
def _get_embedding_max_tokens() -> int:
|
||||
"""Get the max token limit for the configured embedding model.
|
||||
|
|
@ -36,23 +43,25 @@ def truncate_for_embedding(text: str) -> str:
|
|||
if len(text) // 3 <= max_tokens:
|
||||
return text
|
||||
|
||||
tokenizer = config.embedding_model_instance.get_tokenizer()
|
||||
tokens = tokenizer.encode(text)
|
||||
if len(tokens) <= max_tokens:
|
||||
return text
|
||||
with _embedding_lock:
|
||||
tokenizer = config.embedding_model_instance.get_tokenizer()
|
||||
tokens = tokenizer.encode(text)
|
||||
if len(tokens) <= max_tokens:
|
||||
return text
|
||||
|
||||
warnings.warn(
|
||||
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return tokenizer.decode(tokens[:max_tokens])
|
||||
warnings.warn(
|
||||
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return tokenizer.decode(tokens[:max_tokens])
|
||||
|
||||
|
||||
def embed_text(text: str) -> np.ndarray:
    """Embed *text* after truncating it to the model's token limit.

    Drop-in replacement for ``config.embedding_model_instance.embed(text)``
    that never exceeds the model's context window.
    """
    # Hold the lock for both truncation (tokenizer) and embedding: the
    # HF fast tokenizer and model are not safe for concurrent use. The
    # lock is reentrant, so truncate_for_embedding's own acquisition of
    # _embedding_lock nests harmlessly.
    with _embedding_lock:
        safe_text = truncate_for_embedding(text)
        return config.embedding_model_instance.embed(safe_text)
|
||||
|
||||
|
||||
def embed_texts(texts: list[str]) -> list[np.ndarray]:
|
||||
|
|
@ -66,10 +75,11 @@ def embed_texts(texts: list[str]) -> list[np.ndarray]:
|
|||
"""
|
||||
if not texts:
|
||||
return []
|
||||
truncated = [truncate_for_embedding(t) for t in texts]
|
||||
if config.is_local_embedding_model:
|
||||
return [config.embedding_model_instance.embed(t) for t in truncated]
|
||||
return config.embedding_model_instance.embed_batch(truncated)
|
||||
with _embedding_lock:
|
||||
truncated = [truncate_for_embedding(t) for t in texts]
|
||||
if config.is_local_embedding_model:
|
||||
return [config.embedding_model_instance.embed(t) for t in truncated]
|
||||
return config.embedding_model_instance.embed_batch(truncated)
|
||||
|
||||
|
||||
def get_model_context_window(model_name: str) -> int:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue