feat: implement thread-safe embedding access in document converters

- Added a reentrant lock to ensure thread-safe access to the tokenizer and embedding model, preventing runtime errors during concurrent operations.
- Updated the `truncate_for_embedding` and `embed_text` functions to utilize the lock, ensuring safe execution in multi-threaded environments.
- Enhanced the `embed_texts` function to maintain thread safety while processing multiple texts for embedding.
This commit is contained in:
Anish Sarkar 2026-03-27 11:31:00 +05:30
parent db6dd058dd
commit 683a4c17dd

View file

@@ -1,5 +1,6 @@
import hashlib
import logging
import threading
import warnings
import numpy as np
@@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
logger = logging.getLogger(__name__)
# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent
# access from multiple threads causes "RuntimeError: Already borrowed".
# This reentrant lock serialises tokenizer + embedding model access so that
# asyncio.to_thread calls from index_batch_parallel don't collide.
_embedding_lock = threading.RLock()
def _get_embedding_max_tokens() -> int:
"""Get the max token limit for the configured embedding model.
@@ -36,23 +43,25 @@ def truncate_for_embedding(text: str) -> str:
if len(text) // 3 <= max_tokens:
return text
tokenizer = config.embedding_model_instance.get_tokenizer()
tokens = tokenizer.encode(text)
if len(tokens) <= max_tokens:
return text
with _embedding_lock:
tokenizer = config.embedding_model_instance.get_tokenizer()
tokens = tokenizer.encode(text)
if len(tokens) <= max_tokens:
return text
warnings.warn(
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
stacklevel=2,
)
return tokenizer.decode(tokens[:max_tokens])
warnings.warn(
f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
stacklevel=2,
)
return tokenizer.decode(tokens[:max_tokens])
def embed_text(text: str) -> np.ndarray:
    """Truncate *text* to the embedding model's context window and embed it.

    Drop-in replacement for ``config.embedding_model_instance.embed(text)``
    that never exceeds the model's context window.

    Args:
        text: Raw input text; may be longer than the model's token limit.

    Returns:
        The embedding vector produced by the configured model.
    """
    # HF fast tokenizers/models are not thread-safe; serialise access.  The
    # lock is reentrant, so the nested truncate_for_embedding call (which
    # also acquires it) cannot deadlock.
    with _embedding_lock:
        return config.embedding_model_instance.embed(truncate_for_embedding(text))
def embed_texts(texts: list[str]) -> list[np.ndarray]:
    """Truncate and embed a batch of texts.

    Each text is first truncated to fit the model's context window.  Local
    models are embedded one text at a time; remote models are called once
    with the whole batch via ``embed_batch``.

    Args:
        texts: Texts to embed; may be empty.

    Returns:
        One embedding vector per input text (an empty list for empty input).
    """
    if not texts:
        return []
    # Hold the reentrant lock across the whole batch so that concurrent
    # asyncio.to_thread callers cannot interleave tokenizer/model access
    # (HF fast tokenizers raise "RuntimeError: Already borrowed" otherwise).
    with _embedding_lock:
        truncated = [truncate_for_embedding(t) for t in texts]
        if config.is_local_embedding_model:
            return [config.embedding_model_instance.embed(t) for t in truncated]
        return config.embedding_model_instance.embed_batch(truncated)
def get_model_context_window(model_name: str) -> int: