diff --git a/surfsense_backend/app/indexing_pipeline/document_persistence.py b/surfsense_backend/app/indexing_pipeline/document_persistence.py new file mode 100644 index 000000000..f9be92fe6 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/document_persistence.py @@ -0,0 +1,29 @@ +from datetime import UTC, datetime + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import object_session +from sqlalchemy.orm.attributes import set_committed_value + +from app.db import Document, DocumentStatus + + +async def rollback_and_persist_failure( + session: AsyncSession, document: Document, message: str +) -> None: + """Roll back the current transaction, refresh the document, and persist a failed status.""" + await session.rollback() + await session.refresh(document) + document.updated_at = datetime.now(UTC) + document.status = DocumentStatus.failed(message) + await session.commit() + + +def attach_chunks_to_document(document: Document, chunks: list) -> None: + """Assign chunks to a document without triggering SQLAlchemy async lazy loading.""" + set_committed_value(document, "chunks", chunks) + session = object_session(document) + if session is not None: + if document.id is not None: + for chunk in chunks: + chunk.document_id = document.id + session.add_all(chunks) diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 0e114d998..3a9a1de5c 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -2,14 +2,13 @@ from datetime import UTC, datetime from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import object_session -from sqlalchemy.orm.attributes import set_committed_value from app.db import Chunk, Document, DocumentStatus from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_chunker import chunk_text from app.indexing_pipeline.document_embedder import embed_text from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash +from app.indexing_pipeline.document_persistence import attach_chunks_to_document, rollback_and_persist_failure from app.indexing_pipeline.document_summarizer import summarize_document from app.indexing_pipeline.exceptions import ( EMBEDDING_ERRORS, @@ -25,26 +24,6 @@ from app.indexing_pipeline.exceptions import ( ) -async def _mark_failed(session: AsyncSession, document: Document, message: str) -> None: - """Roll back the current transaction, refresh the document, and persist a failed status.""" - await session.rollback() - await session.refresh(document) - document.updated_at = datetime.now(UTC) - document.status = DocumentStatus.failed(message) - await session.commit() - - -def _safe_set_chunks(document: Document, chunks: list) -> None: - """Assign chunks to a document without triggering SQLAlchemy async lazy loading.""" - set_committed_value(document, "chunks", chunks) - session = object_session(document) - if session is not None: - if document.id is not None: - for chunk in chunks: - chunk.document_id = document.id - session.add_all(chunks) - - class IndexingPipelineService: """Single pipeline for indexing connector documents. All connectors use this service.""" @@ -171,28 +150,28 @@ class IndexingPipelineService: document.content = content document.embedding = embedding - _safe_set_chunks(document, chunks) + attach_chunks_to_document(document, chunks) document.updated_at = datetime.now(UTC) document.status = DocumentStatus.ready() await self.session.commit() except RETRYABLE_LLM_ERRORS as e: - await _mark_failed(self.session, document, llm_retryable_message(e)) + await rollback_and_persist_failure(self.session, document, llm_retryable_message(e)) except PERMANENT_LLM_ERRORS as e: - await _mark_failed(self.session, document, llm_permanent_message(e)) + await rollback_and_persist_failure(self.session, document, llm_permanent_message(e)) except EMBEDDING_ERRORS as e: - await _mark_failed(self.session, document, embedding_message(e)) + await rollback_and_persist_failure(self.session, document, embedding_message(e)) except RecursionError: - await _mark_failed(self.session, document, PipelineMessages.CHUNKING_OVERFLOW) + await rollback_and_persist_failure(self.session, document, PipelineMessages.CHUNKING_OVERFLOW) except FATAL_DB_ERRORS: raise except TRANSIENT_DB_ERRORS: - await _mark_failed(self.session, document, PipelineMessages.DB_TRANSIENT) + await rollback_and_persist_failure(self.session, document, PipelineMessages.DB_TRANSIENT) except Exception as e: - await _mark_failed(self.session, document, str(e)) + await rollback_and_persist_failure(self.session, document, str(e))