diff --git a/surfsense_backend/app/utils/document_converters.py b/surfsense_backend/app/utils/document_converters.py
index 0cacdd1d3..ed52c1b7b 100644
--- a/surfsense_backend/app/utils/document_converters.py
+++ b/surfsense_backend/app/utils/document_converters.py
@@ -1,5 +1,6 @@
 import hashlib
 import logging
+import threading
 import warnings
 
 import numpy as np
@@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
 
 logger = logging.getLogger(__name__)
 
+# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent
+# access from multiple threads causes "RuntimeError: Already borrowed".
+# This reentrant lock serialises tokenizer + embedding model access so that
+# asyncio.to_thread calls from index_batch_parallel don't collide.
+_embedding_lock = threading.RLock()
+
 
 def _get_embedding_max_tokens() -> int:
     """Get the max token limit for the configured embedding model.
@@ -36,23 +43,25 @@ def truncate_for_embedding(text: str) -> str:
     if len(text) // 3 <= max_tokens:
         return text
 
-    tokenizer = config.embedding_model_instance.get_tokenizer()
-    tokens = tokenizer.encode(text)
-    if len(tokens) <= max_tokens:
-        return text
+    with _embedding_lock:
+        tokenizer = config.embedding_model_instance.get_tokenizer()
+        tokens = tokenizer.encode(text)
+        if len(tokens) <= max_tokens:
+            return text
 
-    warnings.warn(
-        f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
-        stacklevel=2,
-    )
-    return tokenizer.decode(tokens[:max_tokens])
+        warnings.warn(
+            f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
+            stacklevel=2,
+        )
+        return tokenizer.decode(tokens[:max_tokens])
 
 
 def embed_text(text: str) -> np.ndarray:
     """Truncate text to fit and embed it. Drop-in replacement for
     ``config.embedding_model_instance.embed(text)`` that never exceeds the
     model's context window."""
-    return config.embedding_model_instance.embed(truncate_for_embedding(text))
+    with _embedding_lock:
+        return config.embedding_model_instance.embed(truncate_for_embedding(text))
 
 
 def embed_texts(texts: list[str]) -> list[np.ndarray]:
@@ -66,10 +75,11 @@ def embed_texts(texts: list[str]) -> list[np.ndarray]:
     """
    if not texts:
         return []
-    truncated = [truncate_for_embedding(t) for t in texts]
-    if config.is_local_embedding_model:
-        return [config.embedding_model_instance.embed(t) for t in truncated]
-    return config.embedding_model_instance.embed_batch(truncated)
+    with _embedding_lock:
+        truncated = [truncate_for_embedding(t) for t in texts]
+        if config.is_local_embedding_model:
+            return [config.embedding_model_instance.embed(t) for t in truncated]
+        return config.embedding_model_instance.embed_batch(truncated)
 
 
 def get_model_context_window(model_name: str) -> int:
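
Below the patch, a minimal sketch of the call pattern the new lock guards against. It assumes, as the added comment states, that index_batch_parallel fans embedding work out to worker threads with asyncio.to_thread; the signature and the batches argument here are hypothetical stand-ins, and only embed_texts and the module path come from the patch itself.

# Hypothetical driver, not part of the patch: roughly how index_batch_parallel
# is described as calling these helpers. Each asyncio.to_thread call runs
# embed_texts on a separate worker thread; without _embedding_lock those
# threads could borrow the shared HuggingFace fast tokenizer concurrently and
# raise "RuntimeError: Already borrowed".
import asyncio

from app.utils.document_converters import embed_texts


async def index_batch_parallel(batches: list[list[str]]) -> list[list]:
    # One thread per batch; the module-level RLock inside embed_texts is what
    # serialises tokenizer/embedding access across these threads.
    tasks = [asyncio.to_thread(embed_texts, batch) for batch in batches]
    return await asyncio.gather(*tasks)

An RLock rather than a plain Lock matters here because embed_text and embed_texts call truncate_for_embedding while already holding the lock, and truncate_for_embedding acquires it again; a non-reentrant lock would deadlock on that nested acquisition.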