Merge pull request #837 from CREDO23/test-document-creation

[Refactor] Core document creation / indexing pipeline with unit and e2e tests
This commit is contained in:
Rohan Verma 2026-02-25 12:27:41 -08:00 committed by GitHub
commit 2e99f1e853
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 5535 additions and 3342 deletions

View file

@ -0,0 +1,46 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
async def index_uploaded_file(
    markdown_content: str,
    filename: str,
    etl_service: str,
    search_space_id: int,
    user_id: str,
    session: AsyncSession,
    llm,
) -> None:
    """Index an uploaded file's markdown through the shared indexing pipeline.

    The content is wrapped in a ``ConnectorDocument`` of type
    ``DocumentType.FILE`` whose ``unique_id`` is the filename, then passed to
    ``IndexingPipelineService.prepare_for_indexing`` and ``index``. On success
    ``content_needs_reindexing`` is cleared and the session is committed.

    Args:
        markdown_content: Markdown extracted from the uploaded file.
        filename: Used as title, unique id and ``FILE_NAME`` metadata.
        etl_service: Stored as ``ETL_SERVICE`` metadata.
        search_space_id: Search space the document belongs to.
        user_id: Stored as the document's ``created_by_id``.
        session: Session used by the pipeline and for the final commit.
        llm: Passed to ``index``; when ``None`` the first 4000 characters of
            the markdown serve as the fallback summary.

    Raises:
        RuntimeError: If preparation returns no documents, or if the indexed
            document is not in the READY state (the message is the status
            "reason" when present).
    """
    upload = ConnectorDocument(
        title=filename,
        source_markdown=markdown_content,
        unique_id=filename,
        document_type=DocumentType.FILE,
        search_space_id=search_space_id,
        created_by_id=user_id,
        connector_id=None,
        should_summarize=True,
        should_use_code_chunker=False,
        fallback_summary=markdown_content[:4000],
        metadata={
            "FILE_NAME": filename,
            "ETL_SERVICE": etl_service,
        },
    )

    pipeline = IndexingPipelineService(session)
    prepared = await pipeline.prepare_for_indexing([upload])
    if not prepared:
        raise RuntimeError("prepare_for_indexing returned no documents")

    result = await pipeline.index(prepared[0], upload, llm)
    if DocumentStatus.is_state(result.status, DocumentStatus.READY):
        result.content_needs_reindexing = False
        await session.commit()
        return

    # index() records failures in the status instead of raising; surface them here.
    raise RuntimeError(result.status.get("reason", "Indexing failed"))

View file

@ -0,0 +1,25 @@
from pydantic import BaseModel, Field, field_validator
from app.db import DocumentType
class ConnectorDocument(BaseModel):
    """Connector-agnostic description of one document handed to the indexing pipeline.

    Connector adapters build instances of this model; the pipeline reads them
    to create, update and index ``Document`` rows.
    """

    # Copied to the stored document's title.
    title: str
    # Full markdown body; hashed for change detection and split into chunks.
    source_markdown: str
    # Source-side identifier; hashed together with document_type and search_space_id.
    unique_id: str
    document_type: DocumentType
    # Must be strictly positive.
    search_space_id: int = Field(gt=0)
    # When true the pipeline stores an LLM summary (or fallback_summary) instead of the raw markdown.
    should_summarize: bool = True
    # Selects the code chunker instead of the default chunker.
    should_use_code_chunker: bool = False
    # Used as the stored content when summarizing is requested but no LLM is supplied.
    fallback_summary: str | None = None
    # Stored as the document's metadata and passed to the summarizer.
    metadata: dict = {}
    connector_id: int | None = None
    created_by_id: str

    @field_validator("title", "source_markdown", "unique_id", "created_by_id")
    @classmethod
    def not_empty(cls, v: str, info) -> str:
        """Accept ``v`` unchanged unless it is empty or whitespace-only."""
        if v.strip():
            return v
        raise ValueError(f"{info.field_name} must not be empty or whitespace")

View file

@ -0,0 +1,7 @@
from app.config import config
def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
    """Split ``text`` with the configured chunker and return each chunk's text.

    ``config.code_chunker_instance`` is used when ``use_code_chunker`` is true,
    otherwise ``config.chunker_instance``.
    """
    if use_code_chunker:
        splitter = config.code_chunker_instance
    else:
        splitter = config.chunker_instance
    pieces = splitter.chunk(text)
    return [piece.text for piece in pieces]

View file

@ -0,0 +1,6 @@
from app.config import config
def embed_text(text: str) -> list[float]:
    """Return the embedding of ``text`` from ``config.embedding_model_instance``."""
    model = config.embedding_model_instance
    return model.embed(text)

View file

@ -0,0 +1,15 @@
import hashlib
from app.indexing_pipeline.connector_document import ConnectorDocument
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
    """Return the SHA-256 hex digest identifying a document by its source identity.

    The hashed text is ``"<document_type value>:<unique_id>:<search_space_id>"``
    encoded as UTF-8, so the digest does not depend on the document's content.
    """
    identity = "{}:{}:{}".format(
        doc.document_type.value, doc.unique_id, doc.search_space_id
    )
    hasher = hashlib.sha256()
    hasher.update(identity.encode("utf-8"))
    return hasher.hexdigest()
def compute_content_hash(doc: ConnectorDocument) -> str:
    """Return the SHA-256 hex digest of the document's markdown, scoped to its search space.

    The hashed text is ``"<search_space_id>:<source_markdown>"`` encoded as
    UTF-8, so identical content in different search spaces hashes differently.
    """
    scoped_content = "{}:{}".format(doc.search_space_id, doc.source_markdown)
    hasher = hashlib.sha256()
    hasher.update(scoped_content.encode("utf-8"))
    return hasher.hexdigest()

View file

@ -0,0 +1,39 @@
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import set_committed_value
from app.db import Document, DocumentStatus
async def rollback_and_persist_failure(
session: AsyncSession, document: Document, message: str
) -> None:
"""Roll back the current transaction and best-effort persist a failed status.
Called exclusively from except blocks must never raise, or the new exception
would chain with the original and mask it entirely.
"""
try:
await session.rollback()
except Exception:
return # Session is completely dead; nothing further we can do.
try:
await session.refresh(document)
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.failed(message)
await session.commit()
except Exception:
pass # Best-effort; document will be retried on the next sync.
def attach_chunks_to_document(document: Document, chunks: list) -> None:
    """Set ``document.chunks`` as an already-committed value and register the chunks.

    ``set_committed_value`` is used so that assigning the collection does not
    trigger SQLAlchemy async lazy loading. When the document belongs to a
    session, each chunk is given the document's id (if it has one) and all
    chunks are added to that session.
    """
    set_committed_value(document, "chunks", chunks)

    owning_session = object_session(document)
    if owning_session is None:
        return

    if document.id is not None:
        for chunk in chunks:
            chunk.document_id = document.id
    owning_session.add_all(chunks)

View file

@ -0,0 +1,28 @@
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import optimize_content_for_context_window
async def summarize_document(source_markdown: str, llm, metadata: dict | None = None) -> str:
    """Summarize a document with ``llm`` and return the summary text.

    The markdown is first passed through ``optimize_content_for_context_window``
    using the LLM's ``model`` attribute (``"gpt-3.5-turbo"`` when absent), then
    sent through ``SUMMARY_PROMPT_TEMPLATE | llm`` wrapped in ``<DOCUMENT>`` tags.

    When ``metadata`` is non-empty the result is prefixed with a
    "# DOCUMENT METADATA" section listing every entry with a truthy value,
    followed by a "# DOCUMENT SUMMARY" heading.
    """
    optimized_content = optimize_content_for_context_window(
        source_markdown, metadata, getattr(llm, "model", "gpt-3.5-turbo")
    )

    chain = SUMMARY_PROMPT_TEMPLATE | llm
    document_payload = (
        f"<DOCUMENT><DOCUMENT_METADATA>\n\n{metadata}\n\n</DOCUMENT_METADATA>"
        f"\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
    )
    response = await chain.ainvoke({"document": document_payload})
    summary_text = response.content

    if not metadata:
        return summary_text

    # Keys such as "FILE_NAME" are rendered as "File Name"; falsy values are skipped.
    header_lines = ["# DOCUMENT METADATA"]
    header_lines.extend(
        f"**{key.replace('_', ' ').title()}:** {value}"
        for key, value in metadata.items()
        if value
    )
    header = "\n".join(header_lines)
    return f"{header}\n\n# DOCUMENT SUMMARY\n\n{summary_text}"

View file

@ -0,0 +1,121 @@
from litellm.exceptions import (
APIConnectionError,
APIResponseValidationError,
AuthenticationError,
BadGatewayError,
BadRequestError,
InternalServerError,
NotFoundError,
PermissionDeniedError,
RateLimitError,
ServiceUnavailableError,
Timeout,
UnprocessableEntityError,
)
from sqlalchemy.exc import IntegrityError
# Exception tuples, grouped by failure category, for use directly in except clauses.

# LLM errors whose messages are produced by llm_retryable_message().
RETRYABLE_LLM_ERRORS = (
    RateLimitError,
    Timeout,
    ServiceUnavailableError,
    BadGatewayError,
    InternalServerError,
    APIConnectionError,
)
# LLM errors whose messages are produced by llm_permanent_message().
PERMANENT_LLM_ERRORS = (
    AuthenticationError,
    PermissionDeniedError,
    NotFoundError,
    BadRequestError,
    UnprocessableEntityError,
    APIResponseValidationError,
)
# Errors treated as embedding failures; messages come from embedding_message().
# (LiteLLMEmbeddings, CohereEmbeddings, GeminiEmbeddings all normalize to RuntimeError.)
# NOTE: RecursionError is a subclass of RuntimeError, so handlers that need to
# treat it separately must catch it before this tuple.
EMBEDDING_ERRORS = (
    RuntimeError,  # local device failure or API backend normalization
    OSError,  # model files missing or corrupted (local backends)
    MemoryError,  # document too large for available RAM
)
class PipelineMessages:
    """Fixed failure messages returned by the ``*_message`` helpers in this module."""

    # Returned by llm_retryable_message().
    RATE_LIMIT = "LLM rate limit exceeded. Will retry on next sync."
    LLM_TIMEOUT = "LLM request timed out. Will retry on next sync."
    LLM_UNAVAILABLE = "LLM service temporarily unavailable. Will retry on next sync."
    LLM_BAD_GATEWAY = "LLM gateway error. Will retry on next sync."
    LLM_SERVER_ERROR = "LLM internal server error. Will retry on next sync."
    LLM_CONNECTION = "Could not reach the LLM service. Check network connectivity."
    # Returned by llm_permanent_message().
    LLM_AUTH = "LLM authentication failed. Check your API key."
    LLM_PERMISSION = "LLM request denied. Check your account permissions."
    LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
    LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
    LLM_UNPROCESSABLE = "Document exceeds the LLM context window even after optimization."
    LLM_RESPONSE = "LLM returned an invalid response."
    # Returned by embedding_message().
    EMBEDDING_FAILED = "Embedding failed. Check your embedding model configuration or service."
    EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
    EMBEDDING_MEMORY = "Not enough memory to embed this document."
    # Not returned by any helper in this module; referenced directly by callers.
    CHUNKING_OVERFLOW = "Document structure is too deeply nested to chunk."
def safe_exception_message(exc: Exception) -> str:
    """Return a non-empty, human-readable message for ``exc`` without ever raising.

    Args:
        exc: The exception to describe.

    Returns:
        ``str(exc)`` when it contains visible text; the exception's class name
        when ``str(exc)`` is empty or whitespace-only (for example
        ``RuntimeError()``), so a stored failure reason is never blank; or a
        generic message when stringifying the exception itself raises.
    """
    try:
        text = str(exc)
        # Exceptions raised without arguments stringify to "".
        return text if text.strip() else type(exc).__name__
    except Exception:
        return "Something went wrong during indexing. Error details could not be retrieved."
def llm_retryable_message(exc: Exception) -> str:
    """Map a retryable LLM error to its fixed ``PipelineMessages`` text.

    Exception types are checked in the order listed and the first match wins.
    Unrecognized exceptions fall back to ``safe_exception_message``. Never
    raises: any failure while building the message yields a generic string.
    """
    try:
        known_errors = (
            (RateLimitError, PipelineMessages.RATE_LIMIT),
            (Timeout, PipelineMessages.LLM_TIMEOUT),
            (ServiceUnavailableError, PipelineMessages.LLM_UNAVAILABLE),
            (BadGatewayError, PipelineMessages.LLM_BAD_GATEWAY),
            (InternalServerError, PipelineMessages.LLM_SERVER_ERROR),
            (APIConnectionError, PipelineMessages.LLM_CONNECTION),
        )
        for error_type, message in known_errors:
            if isinstance(exc, error_type):
                return message
        return safe_exception_message(exc)
    except Exception:
        return "Something went wrong when calling the LLM."
def llm_permanent_message(exc: Exception) -> str:
    """Map a permanent LLM error to its fixed ``PipelineMessages`` text.

    Exception types are checked in the order listed and the first match wins.
    Unrecognized exceptions fall back to ``safe_exception_message``. Never
    raises: any failure while building the message yields a generic string.
    """
    try:
        known_errors = (
            (AuthenticationError, PipelineMessages.LLM_AUTH),
            (PermissionDeniedError, PipelineMessages.LLM_PERMISSION),
            (NotFoundError, PipelineMessages.LLM_NOT_FOUND),
            (BadRequestError, PipelineMessages.LLM_BAD_REQUEST),
            (UnprocessableEntityError, PipelineMessages.LLM_UNPROCESSABLE),
            (APIResponseValidationError, PipelineMessages.LLM_RESPONSE),
        )
        matched = next(
            (message for error_type, message in known_errors if isinstance(exc, error_type)),
            None,
        )
        if matched is None:
            return safe_exception_message(exc)
        return matched
    except Exception:
        return "Something went wrong when calling the LLM."
def embedding_message(exc: Exception) -> str:
    """Map an embedding failure to its fixed ``PipelineMessages`` text.

    ``RuntimeError`` is checked first, then ``OSError``, then ``MemoryError``;
    anything else falls back to ``safe_exception_message``. Never raises: any
    failure while building the message yields a generic string.
    """
    try:
        if isinstance(exc, RuntimeError):
            message = PipelineMessages.EMBEDDING_FAILED
        elif isinstance(exc, OSError):
            message = PipelineMessages.EMBEDDING_MODEL
        elif isinstance(exc, MemoryError):
            message = PipelineMessages.EMBEDDING_MEMORY
        else:
            message = safe_exception_message(exc)
        return message
    except Exception:
        return "Something went wrong when generating the embedding."

View file

@ -0,0 +1,237 @@
import contextlib
from datetime import UTC, datetime
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_text
from app.indexing_pipeline.document_hashing import (
compute_content_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
attach_chunks_to_document,
rollback_and_persist_failure,
)
from app.indexing_pipeline.document_summarizer import summarize_document
from app.indexing_pipeline.exceptions import (
EMBEDDING_ERRORS,
PERMANENT_LLM_ERRORS,
RETRYABLE_LLM_ERRORS,
IntegrityError,
PipelineMessages,
embedding_message,
llm_permanent_message,
llm_retryable_message,
safe_exception_message,
)
from app.indexing_pipeline.pipeline_logger import (
PipelineLogContext,
log_batch_aborted,
log_chunking_overflow,
log_doc_skipped_unknown,
log_document_queued,
log_document_requeued,
log_document_updated,
log_embedding_error,
log_index_started,
log_index_success,
log_permanent_llm_error,
log_race_condition,
log_retryable_llm_error,
log_unexpected_error,
)
class IndexingPipelineService:
    """Single pipeline for indexing connector documents. All connectors use this service.

    Work happens in two phases: ``prepare_for_indexing`` creates or updates
    ``Document`` rows and returns the ones that need work, and ``index`` then
    summarizes, embeds and chunks one of those documents.
    """

    def __init__(self, session: AsyncSession) -> None:
        # Every query, commit and rollback issued by this service uses this session.
        self.session = session

    async def prepare_for_indexing(
        self, connector_docs: list[ConnectorDocument]
    ) -> list[Document]:
        """
        Persist new documents and detect changes, returning only those that need indexing.

        For each connector document:
          * a repeated unique-identifier hash within the same batch is skipped;
          * an existing row with unchanged content gets its title synced and is
            re-queued only when its status is not READY;
          * an existing row with changed content is overwritten and set to pending;
          * otherwise a new pending row is added, unless another row already has
            the same content hash.

        An exception raised while handling one document is logged and that
        document is skipped. All changes are committed once at the end; if that
        commit fails the batch is rolled back and an empty list is returned.
        """
        documents = []
        seen_hashes: set[str] = set()
        # Context used only for batch-level log lines (commit failures).
        batch_ctx = PipelineLogContext(
            connector_id=connector_docs[0].connector_id if connector_docs else 0,
            search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
            unique_id="batch",
        )
        for connector_doc in connector_docs:
            ctx = PipelineLogContext(
                connector_id=connector_doc.connector_id,
                search_space_id=connector_doc.search_space_id,
                unique_id=connector_doc.unique_id,
            )
            try:
                unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
                content_hash = compute_content_hash(connector_doc)
                # Only the first occurrence of an identity in this batch is processed.
                if unique_identifier_hash in seen_hashes:
                    continue
                seen_hashes.add(unique_identifier_hash)
                result = await self.session.execute(
                    select(Document).filter(
                        Document.unique_identifier_hash == unique_identifier_hash
                    )
                )
                existing = result.scalars().first()
                if existing is not None:
                    if existing.content_hash == content_hash:
                        # Content unchanged: keep the title in sync with the source.
                        if existing.title != connector_doc.title:
                            existing.title = connector_doc.title
                            existing.updated_at = datetime.now(UTC)
                        # Unchanged but not READY: put it back in the queue.
                        if not DocumentStatus.is_state(
                            existing.status, DocumentStatus.READY
                        ):
                            existing.status = DocumentStatus.pending()
                            existing.updated_at = datetime.now(UTC)
                            documents.append(existing)
                            log_document_requeued(ctx)
                        continue
                    # Content changed: overwrite the stored fields and re-queue.
                    existing.title = connector_doc.title
                    existing.content_hash = content_hash
                    existing.source_markdown = connector_doc.source_markdown
                    existing.document_metadata = connector_doc.metadata
                    existing.updated_at = datetime.now(UTC)
                    existing.status = DocumentStatus.pending()
                    documents.append(existing)
                    log_document_updated(ctx)
                    continue
                # No row for this identity. Skip if a row with the same content
                # hash already exists under a different identity.
                duplicate = await self.session.execute(
                    select(Document).filter(Document.content_hash == content_hash)
                )
                if duplicate.scalars().first() is not None:
                    continue
                document = Document(
                    title=connector_doc.title,
                    document_type=connector_doc.document_type,
                    content="Pending...",
                    content_hash=content_hash,
                    unique_identifier_hash=unique_identifier_hash,
                    source_markdown=connector_doc.source_markdown,
                    document_metadata=connector_doc.metadata,
                    search_space_id=connector_doc.search_space_id,
                    connector_id=connector_doc.connector_id,
                    created_by_id=connector_doc.created_by_id,
                    updated_at=datetime.now(UTC),
                    status=DocumentStatus.pending(),
                )
                self.session.add(document)
                documents.append(document)
                log_document_queued(ctx)
            except Exception as e:
                # One bad document must not abort the rest of the batch.
                log_doc_skipped_unknown(ctx, e)
        try:
            await self.session.commit()
            return documents
        except IntegrityError:
            # A concurrent worker committed a document with the same content_hash
            # or unique_identifier_hash between our check and our INSERT.
            # The document already exists — roll back and let the next sync run handle it.
            log_race_condition(batch_ctx)
            await self.session.rollback()
            return []
        except Exception as e:
            log_batch_aborted(batch_ctx, e)
            await self.session.rollback()
            return []

    async def index(
        self, document: Document, connector_doc: ConnectorDocument, llm
    ) -> Document:
        """
        Run summarization, embedding, and chunking for a document and persist the results.

        Exceptions raised inside the indexing steps are not propagated: each is
        logged and handed to ``rollback_and_persist_failure`` with a message for
        its category, so callers must inspect the returned document's status to
        learn whether indexing succeeded.

        Returns the same ``document`` object, after a best-effort refresh.
        """
        ctx = PipelineLogContext(
            connector_id=connector_doc.connector_id,
            search_space_id=connector_doc.search_space_id,
            unique_id=connector_doc.unique_id,
            doc_id=document.id,
        )
        try:
            log_index_started(ctx)
            # Commit the "processing" status before the indexing work starts.
            document.status = DocumentStatus.processing()
            await self.session.commit()
            # Stored content: LLM summary, else the fallback summary, else the raw markdown.
            if connector_doc.should_summarize and llm is not None:
                content = await summarize_document(
                    connector_doc.source_markdown, llm, connector_doc.metadata
                )
            elif connector_doc.should_summarize and connector_doc.fallback_summary:
                content = connector_doc.fallback_summary
            else:
                content = connector_doc.source_markdown
            embedding = embed_text(content)
            # Remove chunks from any previous indexing run before attaching new ones.
            await self.session.execute(
                delete(Chunk).where(Chunk.document_id == document.id)
            )
            # Chunks are cut from the full markdown, not from the summary.
            chunks = [
                Chunk(content=text, embedding=embed_text(text))
                for text in chunk_text(
                    connector_doc.source_markdown,
                    use_code_chunker=connector_doc.should_use_code_chunker,
                )
            ]
            document.content = content
            document.embedding = embedding
            attach_chunks_to_document(document, chunks)
            document.updated_at = datetime.now(UTC)
            document.status = DocumentStatus.ready()
            await self.session.commit()
            log_index_success(ctx, chunk_count=len(chunks))
        except RETRYABLE_LLM_ERRORS as e:
            log_retryable_llm_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, llm_retryable_message(e)
            )
        except PERMANENT_LLM_ERRORS as e:
            log_permanent_llm_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, llm_permanent_message(e)
            )
        # RecursionError is a RuntimeError subclass, so it must be caught before
        # EMBEDDING_ERRORS (which contains RuntimeError) to get its own message.
        except RecursionError as e:
            log_chunking_overflow(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, PipelineMessages.CHUNKING_OVERFLOW
            )
        except EMBEDDING_ERRORS as e:
            log_embedding_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, embedding_message(e)
            )
        except Exception as e:
            log_unexpected_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, safe_exception_message(e)
            )
        # Best-effort reload of the document; a refresh failure is ignored.
        with contextlib.suppress(Exception):
            await self.session.refresh(document)
        return document

View file

@ -0,0 +1,118 @@
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class PipelineLogContext:
    """Identifying fields appended to every pipeline log line."""

    connector_id: int | None
    search_space_id: int
    unique_id: str  # always available from ConnectorDocument
    doc_id: int | None = None  # set once the DB row exists (index phase only)
class LogMessages:
    """Fixed log message texts, one per ``log_*`` helper in this module."""

    # prepare_for_indexing
    DOCUMENT_QUEUED = "New document queued for indexing."
    DOCUMENT_UPDATED = "Document content changed, re-queued for indexing."
    DOCUMENT_REQUEUED = "Stuck document re-queued for indexing."
    DOC_SKIPPED_UNKNOWN = "Unexpected error — document skipped."
    BATCH_ABORTED = "Fatal DB error — aborting prepare batch."
    RACE_CONDITION = "Concurrent worker beat us to the commit — rolling back batch."
    # index
    INDEX_STARTED = "Document indexing started."
    INDEX_SUCCESS = "Document indexed successfully."
    LLM_RETRYABLE = "Retryable LLM error — document marked failed, will retry on next sync."
    LLM_PERMANENT = "Permanent LLM error — document marked failed."
    EMBEDDING_FAILED = "Embedding error — document marked failed."
    CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."
    UNEXPECTED = "Unexpected error — document marked failed."
def _format_context(ctx: PipelineLogContext) -> str:
    """Render ``ctx`` as space-separated ``key=value`` pairs.

    ``connector_id``, ``search_space_id`` and ``unique_id`` are always present;
    ``doc_id`` is appended only when it is not ``None``.
    """
    rendered = (
        f"connector_id={ctx.connector_id} "
        f"search_space_id={ctx.search_space_id} "
        f"unique_id={ctx.unique_id}"
    )
    if ctx.doc_id is None:
        return rendered
    return f"{rendered} doc_id={ctx.doc_id}"
def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
    """Join ``msg``, the formatted context and any ``key=value`` extras with spaces.

    If formatting the context or an extra value raises, the bare ``msg`` is
    returned instead.
    """
    try:
        pieces = [msg, _format_context(ctx)]
        pieces += [f"{key}={val}" for key, val in extra.items()]
    except Exception:
        return msg
    return " ".join(pieces)
def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra) -> None:
    """Emit one log line through ``level_fn``, swallowing any logging failure.

    ``level_fn`` is called with the built message and ``exc_info``; ``extra``
    keyword arguments are rendered into the message as ``key=value`` pairs.
    """
    # Logging must never raise — a broken log call inside an except block would
    # chain with the original exception and mask it entirely.
    try:
        level_fn(_build_message(msg, ctx, **extra), exc_info=exc_info)
    except Exception:
        return
# ── prepare_for_indexing ──────────────────────────────────────────────────────
def log_document_queued(ctx: PipelineLogContext) -> None:
    """Log at INFO level that a new document was queued for indexing."""
    _safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)
def log_document_updated(ctx: PipelineLogContext) -> None:
    """Log at INFO level that a document's content changed and it was re-queued."""
    _safe_log(logger.info, LogMessages.DOCUMENT_UPDATED, ctx)
def log_document_requeued(ctx: PipelineLogContext) -> None:
    """Log at INFO level that a stuck document was re-queued for indexing."""
    _safe_log(logger.info, LogMessages.DOCUMENT_REQUEUED, ctx)
def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
    """Log at WARNING level that a document was skipped, passing ``exc`` as ``exc_info`` and as an ``error=`` field."""
    _safe_log(logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc)
def log_race_condition(ctx: PipelineLogContext) -> None:
    """Log at WARNING level that a concurrent worker committed first and the batch is rolled back."""
    _safe_log(logger.warning, LogMessages.RACE_CONDITION, ctx)
def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:
    """Log at ERROR level that the prepare batch was aborted, passing ``exc`` as ``exc_info`` and as an ``error=`` field."""
    _safe_log(logger.error, LogMessages.BATCH_ABORTED, ctx, exc_info=exc, error=exc)
# ── index ─────────────────────────────────────────────────────────────────────
def log_index_started(ctx: PipelineLogContext) -> None:
    """Log at INFO level that indexing of a document has started."""
    _safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)
def log_index_success(ctx: PipelineLogContext, chunk_count: int) -> None:
    """Log at INFO level that a document was indexed, including a ``chunk_count=`` field."""
    _safe_log(logger.info, LogMessages.INDEX_SUCCESS, ctx, chunk_count=chunk_count)
def log_retryable_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
    """Log a retryable LLM error at WARNING level, passing ``exc`` as ``exc_info`` and as an ``error=`` field."""
    _safe_log(logger.warning, LogMessages.LLM_RETRYABLE, ctx, exc_info=exc, error=exc)
def log_permanent_llm_error(ctx: PipelineLogContext, exc: Exception) -> None:
    """Log a permanent LLM error at ERROR level, passing ``exc`` as ``exc_info`` and as an ``error=`` field."""
    _safe_log(logger.error, LogMessages.LLM_PERMANENT, ctx, exc_info=exc, error=exc)
def log_embedding_error(ctx: PipelineLogContext, exc: Exception) -> None:
    """Log an embedding error at ERROR level, passing ``exc`` as ``exc_info`` and as an ``error=`` field."""
    _safe_log(logger.error, LogMessages.EMBEDDING_FAILED, ctx, exc_info=exc, error=exc)
def log_chunking_overflow(ctx: PipelineLogContext, exc: Exception) -> None:
    """Log a chunking overflow at ERROR level, passing ``exc`` as ``exc_info`` and as an ``error=`` field."""
    _safe_log(logger.error, LogMessages.CHUNKING_OVERFLOW, ctx, exc_info=exc, error=exc)
def log_unexpected_error(ctx: PipelineLogContext, exc: Exception) -> None:
    """Log an unexpected indexing error at ERROR level, passing ``exc`` as ``exc_info`` and as an ``error=`` field."""
    _safe_log(logger.error, LogMessages.UNEXPECTED, ctx, exc_info=exc, error=exc)

View file

@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType, Log, Notification
from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_file
from app.services.llm_service import get_user_long_context_llm
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
@ -33,7 +34,6 @@ from .base import (
check_document_by_unique_identifier,
check_duplicate_document,
get_current_timestamp,
safe_set_chunks,
)
from .markdown_processor import add_received_markdown_file_document
@ -1863,7 +1863,7 @@ async def process_file_in_background_with_document(
)
return None
# ===== STEP 3: Generate embeddings and chunks =====
# ===== STEP 3+4: Index via pipeline =====
if notification:
await NotificationService.document_processing.notify_processing_progress(
session, notification, stage="chunking"
@ -1871,57 +1871,22 @@ async def process_file_in_background_with_document(
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if user_llm:
document_metadata = {
"file_name": filename,
"etl_service": etl_service,
"document_type": "File Document",
}
summary_content, summary_embedding = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
# Fallback: use truncated content as summary
summary_content = markdown_content[:4000]
from app.config import config
summary_embedding = config.embedding_model_instance.embed(summary_content)
chunks = await create_document_chunks(markdown_content)
# ===== STEP 4: Update document to READY =====
from sqlalchemy.orm.attributes import flag_modified
document.title = filename
document.content = summary_content
document.content_hash = content_hash
document.embedding = summary_embedding
document.document_metadata = {
"FILE_NAME": filename,
"ETL_SERVICE": etl_service or "UNKNOWN",
**(document.document_metadata or {}),
}
flag_modified(document, "document_metadata")
# Use safe_set_chunks to avoid async issues
safe_set_chunks(document, chunks)
document.source_markdown = markdown_content
document.content_needs_reindexing = False
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready() # Shows checkmark in UI
await session.commit()
await session.refresh(document)
await index_uploaded_file(
markdown_content=markdown_content,
filename=filename,
etl_service=etl_service,
search_space_id=search_space_id,
user_id=user_id,
session=session,
llm=user_llm,
)
await task_logger.log_task_success(
log_entry,
f"Successfully processed file: {filename}",
{
"document_id": document.id,
"content_hash": content_hash,
"file_type": etl_service,
"chunks_count": len(chunks),
},
)