Merge pull request #1508 from CREDO23/fix/indexing-batch-chunk-insert

fix(indexing): batch chunk inserts and truncate notification titles
This commit is contained in:
Rohan Verma 2026-06-17 15:26:02 -07:00 committed by GitHub
commit 6a45f24f98
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 364 additions and 124 deletions

View file

@ -959,6 +959,9 @@ class Config:
CHUNK_RECONCILE_ENABLED = (
os.getenv("CHUNK_RECONCILE_ENABLED", "true").strip().lower() == "true"
)
INDEXING_CHUNK_INSERT_BATCH_SIZE = int(
os.getenv("INDEXING_CHUNK_INSERT_BATCH_SIZE", "200")
)
# Proxy provider selection. Maps to a ProxyProvider implementation registered
# in app/utils/proxy/registry.py. Add new vendors there and switch via this var.

View file

@ -1,12 +1,12 @@
import contextlib
import logging
import time
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
from app.db import Document, DocumentStatus
from app.db import Chunk, Document, DocumentStatus
logger = logging.getLogger(__name__)
@ -22,7 +22,6 @@ async def rollback_and_persist_failure(
try:
await session.rollback()
except Exception:
# Session is completely dead; surface it but never raise.
logger.warning(
"Rollback failed; cannot persist failed status for document %s",
getattr(document, "id", "unknown"),
@ -35,8 +34,6 @@ async def rollback_and_persist_failure(
document.status = DocumentStatus.failed(message)
await session.commit()
except Exception:
# Best-effort: the document stays non-ready and is retried next sync.
# Log it so a permanently-stuck document is at least traceable.
logger.warning(
"Could not persist failed status for document %s; will retry next sync",
getattr(document, "id", "unknown"),
@ -46,12 +43,60 @@ async def rollback_and_persist_failure(
await session.rollback()
def attach_chunks_to_document(document: Document, chunks: list) -> None:
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
async def persist_scratch_index(
session: AsyncSession,
document: Document,
content: str,
chunks: list[Chunk],
*,
batch_size: int,
perf: logging.Logger,
) -> None:
"""Commit document content first, then chunk rows in batches, then mark ready."""
if document.id is None:
raise ValueError("document.id is required to persist chunks")
document.content = content
document.updated_at = datetime.now(UTC)
await session.commit()
t_persist = time.perf_counter()
total = len(chunks)
if total == 0:
set_committed_value(document, "chunks", [])
document.status = DocumentStatus.ready()
document.updated_at = datetime.now(UTC)
await session.commit()
return
effective_batch = total if batch_size <= 0 else batch_size
num_batches = (total + effective_batch - 1) // effective_batch
doc_id = document.id
for batch_idx, start in enumerate(range(0, total, effective_batch), start=1):
batch = chunks[start : start + effective_batch]
t_batch = time.perf_counter()
for chunk in batch:
chunk.document_id = doc_id
session.add_all(batch)
await session.commit()
perf.info(
"[indexing] chunk batch doc=%d batch=%d/%d rows=%d in %.3fs",
doc_id,
batch_idx,
num_batches,
len(batch),
time.perf_counter() - t_batch,
)
set_committed_value(document, "chunks", chunks)
session = object_session(document)
if session is not None:
if document.id is not None:
for chunk in chunks:
chunk.document_id = document.id
session.add_all(chunks)
document.status = DocumentStatus.ready()
document.updated_at = datetime.now(UTC)
await session.commit()
perf.info(
"[indexing] chunk persist doc=%d chunks=%d batches=%d in %.3fs",
doc_id,
total,
num_batches,
time.perf_counter() - t_persist,
)

View file

@ -29,7 +29,7 @@ from app.indexing_pipeline.document_hashing import (
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
attach_chunks_to_document,
persist_scratch_index,
rollback_and_persist_failure,
)
from app.indexing_pipeline.exceptions import (
@ -387,21 +387,37 @@ class IndexingPipelineService:
chunk_count = await self._reindex_incrementally(
document, content, connector_doc, existing
)
perf.info(
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
document.id,
chunk_count,
time.perf_counter() - t_step,
)
document.content = content
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.ready()
await self.session.commit()
else:
chunk_count = await self._reindex_from_scratch(
from app.config import config
chunks = await self._reindex_from_scratch(
document, content, connector_doc
)
perf.info(
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
document.id,
chunk_count,
time.perf_counter() - t_step,
)
document.content = content
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.ready()
await self.session.commit()
chunk_count = len(chunks)
perf.info(
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
document.id,
chunk_count,
time.perf_counter() - t_step,
)
await persist_scratch_index(
self.session,
document,
content,
chunks,
batch_size=config.INDEXING_CHUNK_INSERT_BATCH_SIZE,
perf=perf,
)
perf.info(
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
document.id,
@ -484,8 +500,7 @@ class IndexingPipelineService:
async def _reindex_from_scratch(
self, document: Document, content: str, connector_doc: ConnectorDocument
) -> int:
"""First index (or kill-switched re-index): cache-aware full chunk+embed."""
) -> list[Chunk]:
await self.session.execute(
delete(Chunk).where(Chunk.document_id == document.id)
)
@ -495,13 +510,11 @@ class IndexingPipelineService:
use_code_chunker=connector_doc.should_use_code_chunker,
)
chunks = [
document.embedding = summary_embedding
return [
Chunk(content=text, embedding=emb, position=i)
for i, (text, emb) in enumerate(chunk_pairs)
]
document.embedding = summary_embedding
attach_chunks_to_document(document, chunks)
return len(chunks)
async def _reindex_incrementally(
self,

View file

@ -2,6 +2,9 @@
from __future__ import annotations
# Matches notifications.title VARCHAR(200).
TITLE_MAX_LENGTH = 200
# Notifications newer than this are live-synced; older ones load via the list endpoint.
SYNC_WINDOW_DAYS = 14

View file

@ -28,7 +28,7 @@ class DocumentProcessingNotificationHandler(BaseNotificationHandler):
) -> Notification:
"""Open the notification when document processing is queued."""
operation_id = msg.operation_id(document_type, document_name, search_space_id)
title = f"Processing: {document_name}"
title = msg.started_title(document_name)
message = "Waiting in queue"
metadata = {

View file

@ -6,6 +6,8 @@ import hashlib
from datetime import UTC, datetime
from typing import Any
from app.notifications.service.messages.text import format_title
def operation_id(document_type: str, filename: str, search_space_id: int) -> str:
"""Build a unique id for a document processing run."""
@ -14,6 +16,11 @@ def operation_id(document_type: str, filename: str, search_space_id: int) -> str
return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}"
def started_title(document_name: str) -> str:
"""Title shown when document processing is queued."""
return format_title("Processing: ", document_name)
def progress(
stage: str,
stage_message: str | None = None,
@ -44,11 +51,11 @@ def completion(
) -> tuple[str, str, str, dict[str, Any]]:
"""Compute the final title, message, status, and metadata for a finished run."""
if error_message:
title = f"Failed: {document_name}"
title = format_title("Failed: ", document_name)
message = f"Processing failed: {error_message}"
status = "failed"
else:
title = f"Ready: {document_name}"
title = format_title("Ready: ", document_name)
message = "Now searchable!"
status = "completed"

View file

@ -2,7 +2,21 @@
from __future__ import annotations
from app.notifications.constants import TITLE_MAX_LENGTH
def truncate(text: str, limit: int) -> str:
"""Return ``text`` capped at ``limit`` chars, appending an ellipsis if cut."""
return text[:limit] + "..." if len(text) > limit else text
def format_title(prefix: str, text: str, *, max_length: int = TITLE_MAX_LENGTH) -> str:
"""Build a notification title that fits ``max_length`` including ``prefix``."""
budget = max_length - len(prefix)
if budget <= 0:
return prefix[:max_length]
if len(text) <= budget:
return f"{prefix}{text}"
if budget <= 3:
return f"{prefix}{text[:budget]}"
return f"{prefix}{text[: budget - 3]}..."

View file

@ -602,23 +602,29 @@ async def _process_file_upload(
# Create notification for document processing
logger.info(f"[_process_file_upload] Creating notification for: {filename}")
notification = (
await NotificationService.document_processing.notify_processing_started(
session=session,
user_id=UUID(user_id),
document_type="FILE",
document_name=filename,
search_space_id=search_space_id,
file_size=file_size,
notification = None
heartbeat_task = None
try:
notification = (
await NotificationService.document_processing.notify_processing_started(
session=session,
user_id=UUID(user_id),
document_type="FILE",
document_name=filename,
search_space_id=search_space_id,
file_size=file_size,
)
)
logger.info(
f"[_process_file_upload] Notification created with ID: {notification.id}"
)
_start_heartbeat(notification.id)
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
except Exception:
logger.warning(
f"[_process_file_upload] Failed to create notification for: {filename}",
exc_info=True,
)
)
logger.info(
f"[_process_file_upload] Notification created with ID: {notification.id if notification else 'None'}"
)
# Start Redis heartbeat for stale task detection
_start_heartbeat(notification.id)
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
log_entry = await task_logger.log_task_start(
task_name="process_file_upload",
@ -646,23 +652,21 @@ async def _process_file_upload(
# Update notification on success
if result:
await (
NotificationService.document_processing.notify_processing_completed(
if notification:
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
document_id=result.id,
chunks_count=None,
)
)
else:
# Duplicate detected
await (
NotificationService.document_processing.notify_processing_completed(
if notification:
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message="Document already exists (duplicate)",
)
)
except Exception as e:
# Import here to avoid circular dependencies
@ -691,13 +695,13 @@ async def _process_file_upload(
error_message = str(credit_error)
# Create a dedicated insufficient credits notification
try:
# First, mark the processing notification as failed
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message="Insufficient credits",
)
if notification:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message="Insufficient credits",
)
# Then create a separate insufficient_credits notification for better UX
await NotificationService.insufficient_credits.notify_insufficient_credits(
@ -717,12 +721,13 @@ async def _process_file_upload(
# HTTPException with page limit message but no detailed cause
error_message = str(e.detail)
try:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message=error_message,
)
if notification:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message=error_message,
)
except Exception as notif_error:
logger.error(
f"Failed to update notification on failure: {notif_error!s}"
@ -731,13 +736,13 @@ async def _process_file_upload(
error_message = str(e)[:100]
# Update notification on failure - wrapped in try-except to ensure it doesn't fail silently
try:
# Refresh notification to ensure it's not stale after any rollback
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message=error_message,
)
if notification:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message=error_message,
)
except Exception as notif_error:
logger.error(
f"Failed to update notification on failure: {notif_error!s}"
@ -753,8 +758,10 @@ async def _process_file_upload(
raise
finally:
# Stop heartbeat — key deleted on success, expires on crash
heartbeat_task.cancel()
_stop_heartbeat(notification.id)
if heartbeat_task:
heartbeat_task.cancel()
if notification:
_stop_heartbeat(notification.id)
@celery_app.task(name="process_file_upload_with_document", bind=True)
@ -894,29 +901,36 @@ async def _process_file_with_document(
logger.info(
f"[_process_file_with_document] Creating notification for: {filename}"
)
notification = (
await NotificationService.document_processing.notify_processing_started(
session=session,
user_id=UUID(user_id),
document_type="FILE",
document_name=filename,
search_space_id=search_space_id,
file_size=file_size,
notification = None
heartbeat_task = None
try:
notification = (
await NotificationService.document_processing.notify_processing_started(
session=session,
user_id=UUID(user_id),
document_type="FILE",
document_name=filename,
search_space_id=search_space_id,
file_size=file_size,
)
)
)
# Store document_id in notification metadata so cleanup task can find the document
if notification and notification.notification_metadata is not None:
notification.notification_metadata["document_id"] = document_id
from sqlalchemy.orm.attributes import flag_modified
# Store document_id in notification metadata so cleanup task can find the document
if notification.notification_metadata is not None:
notification.notification_metadata["document_id"] = document_id
from sqlalchemy.orm.attributes import flag_modified
flag_modified(notification, "notification_metadata")
await session.commit()
await session.refresh(notification)
flag_modified(notification, "notification_metadata")
await session.commit()
await session.refresh(notification)
# Start Redis heartbeat for stale task detection
_start_heartbeat(notification.id)
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
_start_heartbeat(notification.id)
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
except Exception:
logger.warning(
f"[_process_file_with_document] Failed to create notification for: {filename}",
exc_info=True,
)
log_entry = await task_logger.log_task_start(
task_name="process_file_upload_with_document",
@ -956,14 +970,13 @@ async def _process_file_with_document(
# Update notification on success
if result:
await (
NotificationService.document_processing.notify_processing_completed(
if notification:
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
document_id=result.id,
chunks_count=None,
)
)
logger.info(
f"[_process_file_with_document] Successfully processed document {document_id}"
)
@ -972,13 +985,12 @@ async def _process_file_with_document(
document.status = DocumentStatus.failed("Duplicate content detected")
document.updated_at = get_current_timestamp()
await session.commit()
await (
NotificationService.document_processing.notify_processing_completed(
if notification:
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message="Document already exists (duplicate)",
)
)
except Exception as e:
# Import here to avoid circular dependencies
@ -1009,12 +1021,13 @@ async def _process_file_with_document(
# Handle insufficient-credit errors with dedicated notification
if credit_error is not None:
try:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message="Insufficient credits",
)
if notification:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message="Insufficient credits",
)
await NotificationService.insufficient_credits.notify_insufficient_credits(
session=session,
user_id=UUID(user_id),
@ -1031,12 +1044,13 @@ async def _process_file_with_document(
else:
# Update notification on failure
try:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message=str(e)[:100],
)
if notification:
await session.refresh(notification)
await NotificationService.document_processing.notify_processing_completed(
session=session,
notification=notification,
error_message=str(e)[:100],
)
except Exception as notif_error:
logger.error(
f"Failed to update notification on failure: {notif_error!s}"
@ -1053,8 +1067,10 @@ async def _process_file_with_document(
finally:
# Stop heartbeat — key deleted on success, expires on crash
heartbeat_task.cancel()
_stop_heartbeat(notification.id)
if heartbeat_task:
heartbeat_task.cancel()
if notification:
_stop_heartbeat(notification.id)
# Clean up temp file
if os.path.exists(temp_path):

View file

@ -78,3 +78,23 @@ async def test_processing_completed_failure(
assert done.title == "Failed: report.pdf"
assert done.message == "Processing failed: bad file"
assert done.notification_metadata["status"] == "failed"
async def test_processing_started_truncates_long_filename(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
):
"""A long filename is truncated in the title but kept in metadata."""
long_name = "a" * 250
notification = await handler.notify_processing_started(
session=db_session,
user_id=db_user.id,
document_type="FILE",
document_name=long_name,
search_space_id=db_search_space.id,
)
assert len(notification.title) <= 200
assert notification.title.startswith("Processing: ")
assert notification.title.endswith("...")
assert notification.notification_metadata["document_name"] == long_name

View file

@ -65,10 +65,18 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
"app.indexing_pipeline.cache.cached_indexing.embed_texts",
mock_embed,
)
# Bypass set_committed_value, which requires a real ORM instance (not MagicMock).
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document",
MagicMock(),
pipeline,
"_load_existing_chunks",
AsyncMock(return_value=[]),
)
async def _noop_persist(_session, doc, *_args, **_kwargs):
doc.status = DocumentStatus.ready()
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index",
_noop_persist,
)
connector_doc = make_connector_document(
@ -116,8 +124,17 @@ async def test_non_code_documents_use_hybrid_chunker(
MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]),
)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document",
MagicMock(),
pipeline,
"_load_existing_chunks",
AsyncMock(return_value=[]),
)
async def _noop_persist(_session, doc, *_args, **_kwargs):
doc.status = DocumentStatus.ready()
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index",
_noop_persist,
)
connector_doc = make_connector_document(

View file

@ -0,0 +1,65 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.document_persistence import persist_scratch_index
pytestmark = pytest.mark.unit
def _make_document(doc_id: int = 1) -> Document:
document = MagicMock(spec=Document)
document.id = doc_id
document.content = None
document.status = DocumentStatus.processing()
return document
@pytest.mark.asyncio
async def test_persist_scratch_index_batches_commits(monkeypatch):
monkeypatch.setattr(
"app.indexing_pipeline.document_persistence.set_committed_value",
lambda *_args, **_kwargs: None,
)
session = MagicMock()
session.commit = AsyncMock()
document = _make_document()
chunks = [Chunk(content=f"c{i}", embedding=[0.1], position=i) for i in range(5)]
perf = MagicMock()
await persist_scratch_index(
session,
document,
"body",
chunks,
batch_size=2,
perf=perf,
)
assert session.commit.await_count == 5
assert document.status == DocumentStatus.ready()
@pytest.mark.asyncio
async def test_persist_scratch_index_empty_chunks(monkeypatch):
monkeypatch.setattr(
"app.indexing_pipeline.document_persistence.set_committed_value",
lambda *_args, **_kwargs: None,
)
session = MagicMock()
session.commit = AsyncMock()
document = _make_document()
perf = MagicMock()
await persist_scratch_index(
session,
document,
"body",
[],
batch_size=200,
perf=perf,
)
assert session.commit.await_count == 2
assert document.status == DocumentStatus.ready()

View file

@ -61,3 +61,21 @@ def test_completion_failure():
assert message == "Processing failed: bad"
assert status == "failed"
assert meta["processing_stage"] == "failed"
def test_started_title_truncates_long_name():
"""Very long filenames are truncated to fit the notification title column."""
long_name = "a" * 250
title = msg.started_title(long_name)
assert len(title) <= 200
assert title.startswith("Processing: ")
assert title.endswith("...")
def test_completion_truncates_long_name():
"""Completion titles truncate long document names."""
long_name = "b" * 250
title, _, _, _ = msg.completion(long_name, document_id=1)
assert len(title) <= 200
assert title.startswith("Ready: ")
assert title.endswith("...")

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import pytest
from app.notifications.service.messages.text import truncate
from app.notifications.service.messages.text import format_title, truncate
pytestmark = pytest.mark.unit
@ -22,3 +22,22 @@ def test_truncate_keeps_text_at_exact_limit():
def test_truncate_appends_ellipsis_when_over_limit():
"""Text past the limit is cut to the limit and gains an ellipsis."""
assert truncate("a" * 41, 40) == "a" * 40 + "..."
def test_format_title_keeps_short_name():
"""Short names are joined to the prefix without truncation."""
assert format_title("Ready: ", "report.pdf") == "Ready: report.pdf"
def test_format_title_truncates_long_name():
"""Long names are truncated so the full title fits the DB limit."""
long_name = "a" * 250
title = format_title("Processing: ", long_name)
assert len(title) == 200
assert title.startswith("Processing: ")
assert title.endswith("...")
def test_format_title_respects_custom_max_length():
"""A custom max length caps the title."""
assert len(format_title("Go: ", "hello world", max_length=10)) == 10