batch chunk inserts in persist_scratch_index

This commit is contained in:
CREDO23 2026-06-17 14:59:24 +02:00
parent 220d9c4fbb
commit 34de6c6f87

View file

@ -1,12 +1,12 @@
import contextlib
import logging
import time
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
from app.db import Document, DocumentStatus
from app.db import Chunk, Document, DocumentStatus
logger = logging.getLogger(__name__)
@ -22,7 +22,6 @@ async def rollback_and_persist_failure(
try:
await session.rollback()
except Exception:
# Session is completely dead; surface it but never raise.
logger.warning(
"Rollback failed; cannot persist failed status for document %s",
getattr(document, "id", "unknown"),
@ -35,8 +34,6 @@ async def rollback_and_persist_failure(
document.status = DocumentStatus.failed(message)
await session.commit()
except Exception:
# Best-effort: the document stays non-ready and is retried next sync.
# Log it so a permanently-stuck document is at least traceable.
logger.warning(
"Could not persist failed status for document %s; will retry next sync",
getattr(document, "id", "unknown"),
@ -46,12 +43,60 @@ async def rollback_and_persist_failure(
await session.rollback()
def attach_chunks_to_document(document: Document, chunks: list) -> None:
"""Assign chunks to a document without triggering SQLAlchemy async lazy loading."""
async def persist_scratch_index(
    session: AsyncSession,
    document: Document,
    content: str,
    chunks: list[Chunk],
    *,
    batch_size: int,
    perf: logging.Logger,
) -> None:
    """Commit document content first, then chunk rows in batches, then mark ready.

    The work is split into separate commits: one for the document content,
    one per chunk batch, and a final one that flips the document status to
    ready. Nothing here rolls back on failure; an exception from any commit
    propagates to the caller.

    Args:
        session: Async session used for every commit in this function.
        document: Document whose ``id`` is already assigned; it receives the
            content, the chunk collection, and the final ready status.
        content: Text stored on ``document.content`` before any chunk is written.
        chunks: Chunk rows to insert; each gets ``document_id`` set here.
        batch_size: Maximum number of chunk rows per commit. A value ``<= 0``
            writes all chunks in a single batch.
        perf: Logger that receives the per-batch and overall timing lines.

    Raises:
        ValueError: If ``document.id`` is ``None``.
    """
    # Capture the primary key before the first commit. Reading it afterwards
    # would touch an expired attribute if the session expires on commit, and
    # an async session cannot refresh it implicitly.
    doc_id = document.id
    if doc_id is None:
        raise ValueError("document.id is required to persist chunks")

    document.content = content
    document.updated_at = datetime.now(UTC)
    await session.commit()

    # The overall timer deliberately starts after the content commit, so it
    # measures chunk persistence only.
    t_persist = time.perf_counter()
    total = len(chunks)

    if total == 0:
        # Mark the relationship as loaded-and-empty without emitting a lazy load.
        set_committed_value(document, "chunks", [])
        document.status = DocumentStatus.ready()
        document.updated_at = datetime.now(UTC)
        await session.commit()
        return

    effective_batch = total if batch_size <= 0 else batch_size
    num_batches = (total + effective_batch - 1) // effective_batch  # ceil division

    for batch_idx, start in enumerate(range(0, total, effective_batch), start=1):
        batch = chunks[start : start + effective_batch]
        t_batch = time.perf_counter()
        for chunk in batch:
            chunk.document_id = doc_id
        session.add_all(batch)
        await session.commit()
        perf.info(
            "[indexing] chunk batch doc=%d batch=%d/%d rows=%d in %.3fs",
            doc_id,
            batch_idx,
            num_batches,
            len(batch),
            time.perf_counter() - t_batch,
        )

    # Record the chunks as the committed value of the relationship so later
    # reads of document.chunks do not trigger a lazy load.
    set_committed_value(document, "chunks", chunks)
    document.status = DocumentStatus.ready()
    document.updated_at = datetime.now(UTC)
    await session.commit()

    perf.info(
        "[indexing] chunk persist doc=%d chunks=%d batches=%d in %.3fs",
        doc_id,
        total,
        num_batches,
        time.perf_counter() - t_persist,
    )