mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-18 21:15:16 +02:00
Merge pull request #1508 from CREDO23/fix/indexing-batch-chunk-insert
fix(indexing): batch chunk inserts and truncate notification titles
This commit is contained in:
commit
6a45f24f98
13 changed files with 364 additions and 124 deletions
|
|
@ -959,6 +959,9 @@ class Config:
|
|||
CHUNK_RECONCILE_ENABLED = (
|
||||
os.getenv("CHUNK_RECONCILE_ENABLED", "true").strip().lower() == "true"
|
||||
)
|
||||
INDEXING_CHUNK_INSERT_BATCH_SIZE = int(
|
||||
os.getenv("INDEXING_CHUNK_INSERT_BATCH_SIZE", "200")
|
||||
)
|
||||
|
||||
# Proxy provider selection. Maps to a ProxyProvider implementation registered
|
||||
# in app/utils/proxy/registry.py. Add new vendors there and switch via this var.
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.orm.attributes import set_committed_value
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -22,7 +22,6 @@ async def rollback_and_persist_failure(
|
|||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
# Session is completely dead; surface it but never raise.
|
||||
logger.warning(
|
||||
"Rollback failed; cannot persist failed status for document %s",
|
||||
getattr(document, "id", "unknown"),
|
||||
|
|
@ -35,8 +34,6 @@ async def rollback_and_persist_failure(
|
|||
document.status = DocumentStatus.failed(message)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
# Best-effort: the document stays non-ready and is retried next sync.
|
||||
# Log it so a permanently-stuck document is at least traceable.
|
||||
logger.warning(
|
||||
"Could not persist failed status for document %s; will retry next sync",
|
||||
getattr(document, "id", "unknown"),
|
||||
|
|
@ -46,12 +43,60 @@ async def rollback_and_persist_failure(
|
|||
await session.rollback()
|
||||
|
||||
|
||||
def attach_chunks_to_document(document: Document, chunks: list) -> None:
|
||||
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
|
||||
async def persist_scratch_index(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
content: str,
|
||||
chunks: list[Chunk],
|
||||
*,
|
||||
batch_size: int,
|
||||
perf: logging.Logger,
|
||||
) -> None:
|
||||
"""Commit document content first, then chunk rows in batches, then mark ready."""
|
||||
if document.id is None:
|
||||
raise ValueError("document.id is required to persist chunks")
|
||||
|
||||
document.content = content
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
t_persist = time.perf_counter()
|
||||
total = len(chunks)
|
||||
if total == 0:
|
||||
set_committed_value(document, "chunks", [])
|
||||
document.status = DocumentStatus.ready()
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
effective_batch = total if batch_size <= 0 else batch_size
|
||||
num_batches = (total + effective_batch - 1) // effective_batch
|
||||
doc_id = document.id
|
||||
|
||||
for batch_idx, start in enumerate(range(0, total, effective_batch), start=1):
|
||||
batch = chunks[start : start + effective_batch]
|
||||
t_batch = time.perf_counter()
|
||||
for chunk in batch:
|
||||
chunk.document_id = doc_id
|
||||
session.add_all(batch)
|
||||
await session.commit()
|
||||
perf.info(
|
||||
"[indexing] chunk batch doc=%d batch=%d/%d rows=%d in %.3fs",
|
||||
doc_id,
|
||||
batch_idx,
|
||||
num_batches,
|
||||
len(batch),
|
||||
time.perf_counter() - t_batch,
|
||||
)
|
||||
|
||||
set_committed_value(document, "chunks", chunks)
|
||||
session = object_session(document)
|
||||
if session is not None:
|
||||
if document.id is not None:
|
||||
for chunk in chunks:
|
||||
chunk.document_id = document.id
|
||||
session.add_all(chunks)
|
||||
document.status = DocumentStatus.ready()
|
||||
document.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
perf.info(
|
||||
"[indexing] chunk persist doc=%d chunks=%d batches=%d in %.3fs",
|
||||
doc_id,
|
||||
total,
|
||||
num_batches,
|
||||
time.perf_counter() - t_persist,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from app.indexing_pipeline.document_hashing import (
|
|||
compute_unique_identifier_hash,
|
||||
)
|
||||
from app.indexing_pipeline.document_persistence import (
|
||||
attach_chunks_to_document,
|
||||
persist_scratch_index,
|
||||
rollback_and_persist_failure,
|
||||
)
|
||||
from app.indexing_pipeline.exceptions import (
|
||||
|
|
@ -387,21 +387,37 @@ class IndexingPipelineService:
|
|||
chunk_count = await self._reindex_incrementally(
|
||||
document, content, connector_doc, existing
|
||||
)
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
document.content = content
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
else:
|
||||
chunk_count = await self._reindex_from_scratch(
|
||||
from app.config import config
|
||||
|
||||
chunks = await self._reindex_from_scratch(
|
||||
document, content, connector_doc
|
||||
)
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
document.content = content
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
chunk_count = len(chunks)
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
chunk_count,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
await persist_scratch_index(
|
||||
self.session,
|
||||
document,
|
||||
content,
|
||||
chunks,
|
||||
batch_size=config.INDEXING_CHUNK_INSERT_BATCH_SIZE,
|
||||
perf=perf,
|
||||
)
|
||||
perf.info(
|
||||
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
|
|
@ -484,8 +500,7 @@ class IndexingPipelineService:
|
|||
|
||||
async def _reindex_from_scratch(
|
||||
self, document: Document, content: str, connector_doc: ConnectorDocument
|
||||
) -> int:
|
||||
"""First index (or kill-switched re-index): cache-aware full chunk+embed."""
|
||||
) -> list[Chunk]:
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
|
@ -495,13 +510,11 @@ class IndexingPipelineService:
|
|||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
|
||||
chunks = [
|
||||
document.embedding = summary_embedding
|
||||
return [
|
||||
Chunk(content=text, embedding=emb, position=i)
|
||||
for i, (text, emb) in enumerate(chunk_pairs)
|
||||
]
|
||||
document.embedding = summary_embedding
|
||||
attach_chunks_to_document(document, chunks)
|
||||
return len(chunks)
|
||||
|
||||
async def _reindex_incrementally(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
# Matches notifications.title VARCHAR(200).
|
||||
TITLE_MAX_LENGTH = 200
|
||||
|
||||
# Notifications newer than this are live-synced; older ones load via the list endpoint.
|
||||
SYNC_WINDOW_DAYS = 14
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class DocumentProcessingNotificationHandler(BaseNotificationHandler):
|
|||
) -> Notification:
|
||||
"""Open the notification when document processing is queued."""
|
||||
operation_id = msg.operation_id(document_type, document_name, search_space_id)
|
||||
title = f"Processing: {document_name}"
|
||||
title = msg.started_title(document_name)
|
||||
message = "Waiting in queue"
|
||||
|
||||
metadata = {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import hashlib
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.notifications.service.messages.text import format_title
|
||||
|
||||
|
||||
def operation_id(document_type: str, filename: str, search_space_id: int) -> str:
|
||||
"""Build a unique id for a document processing run."""
|
||||
|
|
@ -14,6 +16,11 @@ def operation_id(document_type: str, filename: str, search_space_id: int) -> str
|
|||
return f"doc_{document_type}_{search_space_id}_{timestamp}_{filename_hash}"
|
||||
|
||||
|
||||
def started_title(document_name: str) -> str:
|
||||
"""Title shown when document processing is queued."""
|
||||
return format_title("Processing: ", document_name)
|
||||
|
||||
|
||||
def progress(
|
||||
stage: str,
|
||||
stage_message: str | None = None,
|
||||
|
|
@ -44,11 +51,11 @@ def completion(
|
|||
) -> tuple[str, str, str, dict[str, Any]]:
|
||||
"""Compute the final title, message, status, and metadata for a finished run."""
|
||||
if error_message:
|
||||
title = f"Failed: {document_name}"
|
||||
title = format_title("Failed: ", document_name)
|
||||
message = f"Processing failed: {error_message}"
|
||||
status = "failed"
|
||||
else:
|
||||
title = f"Ready: {document_name}"
|
||||
title = format_title("Ready: ", document_name)
|
||||
message = "Now searchable!"
|
||||
status = "completed"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,21 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.notifications.constants import TITLE_MAX_LENGTH
|
||||
|
||||
|
||||
def truncate(text: str, limit: int) -> str:
|
||||
"""Return ``text`` capped at ``limit`` chars, appending an ellipsis if cut."""
|
||||
return text[:limit] + "..." if len(text) > limit else text
|
||||
|
||||
|
||||
def format_title(prefix: str, text: str, *, max_length: int = TITLE_MAX_LENGTH) -> str:
|
||||
"""Build a notification title that fits ``max_length`` including ``prefix``."""
|
||||
budget = max_length - len(prefix)
|
||||
if budget <= 0:
|
||||
return prefix[:max_length]
|
||||
if len(text) <= budget:
|
||||
return f"{prefix}{text}"
|
||||
if budget <= 3:
|
||||
return f"{prefix}{text[:budget]}"
|
||||
return f"{prefix}{text[: budget - 3]}..."
|
||||
|
|
|
|||
|
|
@ -602,23 +602,29 @@ async def _process_file_upload(
|
|||
|
||||
# Create notification for document processing
|
||||
logger.info(f"[_process_file_upload] Creating notification for: {filename}")
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="FILE",
|
||||
document_name=filename,
|
||||
search_space_id=search_space_id,
|
||||
file_size=file_size,
|
||||
notification = None
|
||||
heartbeat_task = None
|
||||
try:
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="FILE",
|
||||
document_name=filename,
|
||||
search_space_id=search_space_id,
|
||||
file_size=file_size,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"[_process_file_upload] Notification created with ID: {notification.id}"
|
||||
)
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"[_process_file_upload] Failed to create notification for: {filename}",
|
||||
exc_info=True,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"[_process_file_upload] Notification created with ID: {notification.id if notification else 'None'}"
|
||||
)
|
||||
|
||||
# Start Redis heartbeat for stale task detection
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_file_upload",
|
||||
|
|
@ -646,23 +652,21 @@ async def _process_file_upload(
|
|||
|
||||
# Update notification on success
|
||||
if result:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
document_id=result.id,
|
||||
chunks_count=None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Duplicate detected
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Document already exists (duplicate)",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Import here to avoid circular dependencies
|
||||
|
|
@ -691,13 +695,13 @@ async def _process_file_upload(
|
|||
error_message = str(credit_error)
|
||||
# Create a dedicated insufficient credits notification
|
||||
try:
|
||||
# First, mark the processing notification as failed
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Insufficient credits",
|
||||
)
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Insufficient credits",
|
||||
)
|
||||
|
||||
# Then create a separate insufficient_credits notification for better UX
|
||||
await NotificationService.insufficient_credits.notify_insufficient_credits(
|
||||
|
|
@ -717,12 +721,13 @@ async def _process_file_upload(
|
|||
# HTTPException with page limit message but no detailed cause
|
||||
error_message = str(e.detail)
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=error_message,
|
||||
)
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on failure: {notif_error!s}"
|
||||
|
|
@ -731,13 +736,13 @@ async def _process_file_upload(
|
|||
error_message = str(e)[:100]
|
||||
# Update notification on failure - wrapped in try-except to ensure it doesn't fail silently
|
||||
try:
|
||||
# Refresh notification to ensure it's not stale after any rollback
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=error_message,
|
||||
)
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on failure: {notif_error!s}"
|
||||
|
|
@ -753,8 +758,10 @@ async def _process_file_upload(
|
|||
raise
|
||||
finally:
|
||||
# Stop heartbeat — key deleted on success, expires on crash
|
||||
heartbeat_task.cancel()
|
||||
_stop_heartbeat(notification.id)
|
||||
if heartbeat_task:
|
||||
heartbeat_task.cancel()
|
||||
if notification:
|
||||
_stop_heartbeat(notification.id)
|
||||
|
||||
|
||||
@celery_app.task(name="process_file_upload_with_document", bind=True)
|
||||
|
|
@ -894,29 +901,36 @@ async def _process_file_with_document(
|
|||
logger.info(
|
||||
f"[_process_file_with_document] Creating notification for: {filename}"
|
||||
)
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="FILE",
|
||||
document_name=filename,
|
||||
search_space_id=search_space_id,
|
||||
file_size=file_size,
|
||||
notification = None
|
||||
heartbeat_task = None
|
||||
try:
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="FILE",
|
||||
document_name=filename,
|
||||
search_space_id=search_space_id,
|
||||
file_size=file_size,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Store document_id in notification metadata so cleanup task can find the document
|
||||
if notification and notification.notification_metadata is not None:
|
||||
notification.notification_metadata["document_id"] = document_id
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
# Store document_id in notification metadata so cleanup task can find the document
|
||||
if notification.notification_metadata is not None:
|
||||
notification.notification_metadata["document_id"] = document_id
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
flag_modified(notification, "notification_metadata")
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
flag_modified(notification, "notification_metadata")
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
|
||||
# Start Redis heartbeat for stale task detection
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"[_process_file_with_document] Failed to create notification for: {filename}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_file_upload_with_document",
|
||||
|
|
@ -956,14 +970,13 @@ async def _process_file_with_document(
|
|||
|
||||
# Update notification on success
|
||||
if result:
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
document_id=result.id,
|
||||
chunks_count=None,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"[_process_file_with_document] Successfully processed document {document_id}"
|
||||
)
|
||||
|
|
@ -972,13 +985,12 @@ async def _process_file_with_document(
|
|||
document.status = DocumentStatus.failed("Duplicate content detected")
|
||||
document.updated_at = get_current_timestamp()
|
||||
await session.commit()
|
||||
await (
|
||||
NotificationService.document_processing.notify_processing_completed(
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Document already exists (duplicate)",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Import here to avoid circular dependencies
|
||||
|
|
@ -1009,12 +1021,13 @@ async def _process_file_with_document(
|
|||
# Handle insufficient-credit errors with dedicated notification
|
||||
if credit_error is not None:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Insufficient credits",
|
||||
)
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message="Insufficient credits",
|
||||
)
|
||||
await NotificationService.insufficient_credits.notify_insufficient_credits(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
|
|
@ -1031,12 +1044,13 @@ async def _process_file_with_document(
|
|||
else:
|
||||
# Update notification on failure
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=str(e)[:100],
|
||||
)
|
||||
if notification:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=str(e)[:100],
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on failure: {notif_error!s}"
|
||||
|
|
@ -1053,8 +1067,10 @@ async def _process_file_with_document(
|
|||
|
||||
finally:
|
||||
# Stop heartbeat — key deleted on success, expires on crash
|
||||
heartbeat_task.cancel()
|
||||
_stop_heartbeat(notification.id)
|
||||
if heartbeat_task:
|
||||
heartbeat_task.cancel()
|
||||
if notification:
|
||||
_stop_heartbeat(notification.id)
|
||||
|
||||
# Clean up temp file
|
||||
if os.path.exists(temp_path):
|
||||
|
|
|
|||
|
|
@ -78,3 +78,23 @@ async def test_processing_completed_failure(
|
|||
assert done.title == "Failed: report.pdf"
|
||||
assert done.message == "Processing failed: bad file"
|
||||
assert done.notification_metadata["status"] == "failed"
|
||||
|
||||
|
||||
async def test_processing_started_truncates_long_filename(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
):
|
||||
"""A long filename is truncated in the title but kept in metadata."""
|
||||
long_name = "a" * 250
|
||||
|
||||
notification = await handler.notify_processing_started(
|
||||
session=db_session,
|
||||
user_id=db_user.id,
|
||||
document_type="FILE",
|
||||
document_name=long_name,
|
||||
search_space_id=db_search_space.id,
|
||||
)
|
||||
|
||||
assert len(notification.title) <= 200
|
||||
assert notification.title.startswith("Processing: ")
|
||||
assert notification.title.endswith("...")
|
||||
assert notification.notification_metadata["document_name"] == long_name
|
||||
|
|
|
|||
|
|
@ -65,10 +65,18 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
|||
"app.indexing_pipeline.cache.cached_indexing.embed_texts",
|
||||
mock_embed,
|
||||
)
|
||||
# Bypass set_committed_value, which requires a real ORM instance (not MagicMock).
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document",
|
||||
MagicMock(),
|
||||
pipeline,
|
||||
"_load_existing_chunks",
|
||||
AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
async def _noop_persist(_session, doc, *_args, **_kwargs):
|
||||
doc.status = DocumentStatus.ready()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index",
|
||||
_noop_persist,
|
||||
)
|
||||
|
||||
connector_doc = make_connector_document(
|
||||
|
|
@ -116,8 +124,17 @@ async def test_non_code_documents_use_hybrid_chunker(
|
|||
MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document",
|
||||
MagicMock(),
|
||||
pipeline,
|
||||
"_load_existing_chunks",
|
||||
AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
async def _noop_persist(_session, doc, *_args, **_kwargs):
|
||||
doc.status = DocumentStatus.ready()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.indexing_pipeline_service.persist_scratch_index",
|
||||
_noop_persist,
|
||||
)
|
||||
|
||||
connector_doc = make_connector_document(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import Chunk, Document, DocumentStatus
|
||||
from app.indexing_pipeline.document_persistence import persist_scratch_index
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_document(doc_id: int = 1) -> Document:
|
||||
document = MagicMock(spec=Document)
|
||||
document.id = doc_id
|
||||
document.content = None
|
||||
document.status = DocumentStatus.processing()
|
||||
return document
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_scratch_index_batches_commits(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.document_persistence.set_committed_value",
|
||||
lambda *_args, **_kwargs: None,
|
||||
)
|
||||
session = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
document = _make_document()
|
||||
chunks = [Chunk(content=f"c{i}", embedding=[0.1], position=i) for i in range(5)]
|
||||
perf = MagicMock()
|
||||
|
||||
await persist_scratch_index(
|
||||
session,
|
||||
document,
|
||||
"body",
|
||||
chunks,
|
||||
batch_size=2,
|
||||
perf=perf,
|
||||
)
|
||||
|
||||
assert session.commit.await_count == 5
|
||||
assert document.status == DocumentStatus.ready()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_scratch_index_empty_chunks(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"app.indexing_pipeline.document_persistence.set_committed_value",
|
||||
lambda *_args, **_kwargs: None,
|
||||
)
|
||||
session = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
document = _make_document()
|
||||
perf = MagicMock()
|
||||
|
||||
await persist_scratch_index(
|
||||
session,
|
||||
document,
|
||||
"body",
|
||||
[],
|
||||
batch_size=200,
|
||||
perf=perf,
|
||||
)
|
||||
|
||||
assert session.commit.await_count == 2
|
||||
assert document.status == DocumentStatus.ready()
|
||||
|
|
@ -61,3 +61,21 @@ def test_completion_failure():
|
|||
assert message == "Processing failed: bad"
|
||||
assert status == "failed"
|
||||
assert meta["processing_stage"] == "failed"
|
||||
|
||||
|
||||
def test_started_title_truncates_long_name():
|
||||
"""Very long filenames are truncated to fit the notification title column."""
|
||||
long_name = "a" * 250
|
||||
title = msg.started_title(long_name)
|
||||
assert len(title) <= 200
|
||||
assert title.startswith("Processing: ")
|
||||
assert title.endswith("...")
|
||||
|
||||
|
||||
def test_completion_truncates_long_name():
|
||||
"""Completion titles truncate long document names."""
|
||||
long_name = "b" * 250
|
||||
title, _, _, _ = msg.completion(long_name, document_id=1)
|
||||
assert len(title) <= 200
|
||||
assert title.startswith("Ready: ")
|
||||
assert title.endswith("...")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import pytest
|
||||
|
||||
from app.notifications.service.messages.text import truncate
|
||||
from app.notifications.service.messages.text import format_title, truncate
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
|
@ -22,3 +22,22 @@ def test_truncate_keeps_text_at_exact_limit():
|
|||
def test_truncate_appends_ellipsis_when_over_limit():
|
||||
"""Text past the limit is cut to the limit and gains an ellipsis."""
|
||||
assert truncate("a" * 41, 40) == "a" * 40 + "..."
|
||||
|
||||
|
||||
def test_format_title_keeps_short_name():
|
||||
"""Short names are joined to the prefix without truncation."""
|
||||
assert format_title("Ready: ", "report.pdf") == "Ready: report.pdf"
|
||||
|
||||
|
||||
def test_format_title_truncates_long_name():
|
||||
"""Long names are truncated so the full title fits the DB limit."""
|
||||
long_name = "a" * 250
|
||||
title = format_title("Processing: ", long_name)
|
||||
assert len(title) == 200
|
||||
assert title.startswith("Processing: ")
|
||||
assert title.endswith("...")
|
||||
|
||||
|
||||
def test_format_title_respects_custom_max_length():
|
||||
"""A custom max length caps the title."""
|
||||
assert len(format_title("Go: ", "hello world", max_length=10)) == 10
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue