Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-04-29 02:46:25 +02:00)
- Added performance logging to the `index_batch_parallel` method, capturing metrics for document indexing duration and concurrency.
- Introduced timing measurements for both the overall indexing process and the parallel document gathering phase, improving observability of the indexing workflow.
- Updated logging statements to provide detailed insights into the number of documents processed, indexed, and failed during the indexing operation.
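For context, a minimal caller sketch of the instrumented path: it drives `index_batch_parallel` with a heartbeat callback so the new duration and concurrency metrics are emitted alongside progress counts. The import path and the `load_llm` / `update_task_progress` helpers are placeholders for illustration, not APIs introduced by this change.

```python
# Hypothetical caller: the module path and helper callables are assumptions.
from sqlalchemy.ext.asyncio import AsyncSession

from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService


async def run_indexing(session: AsyncSession, docs, load_llm, update_task_progress):
    pipeline = IndexingPipelineService(session)

    async def get_llm(isolated_session: AsyncSession):
        # One LLM handle per isolated worker session.
        return await load_llm(isolated_session)

    async def on_heartbeat(indexed_so_far: int) -> None:
        # Invoked at most once per heartbeat_interval from the parallel workers.
        await update_task_progress(indexed_so_far)

    results, indexed, failed = await pipeline.index_batch_parallel(
        docs,
        get_llm,
        max_concurrency=4,
        on_heartbeat=on_heartbeat,
        heartbeat_interval=30.0,
    )
    # Per-phase timings ("summarize_document", "chunk+embed", "gather", "TOTAL")
    # appear on the perf logger configured via get_perf_logger().
    return results, indexed, failed
```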
454 lines
17 KiB
Python
import asyncio
import contextlib
import logging
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime

from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_texts
from app.indexing_pipeline.document_hashing import (
    compute_content_hash,
    compute_identifier_hash,
    compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
    attach_chunks_to_document,
    rollback_and_persist_failure,
)
from app.indexing_pipeline.document_summarizer import summarize_document
from app.indexing_pipeline.exceptions import (
    EMBEDDING_ERRORS,
    PERMANENT_LLM_ERRORS,
    RETRYABLE_LLM_ERRORS,
    PipelineMessages,
    embedding_message,
    llm_permanent_message,
    llm_retryable_message,
    safe_exception_message,
)
from app.indexing_pipeline.pipeline_logger import (
    PipelineLogContext,
    log_batch_aborted,
    log_chunking_overflow,
    log_doc_skipped_unknown,
    log_document_queued,
    log_document_requeued,
    log_document_updated,
    log_embedding_error,
    log_index_started,
    log_index_success,
    log_permanent_llm_error,
    log_race_condition,
    log_retryable_llm_error,
    log_unexpected_error,
)
from app.utils.perf import get_perf_logger


class IndexingPipelineService:
    """Single pipeline for indexing connector documents. All connectors use this service."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def migrate_legacy_docs(
        self, connector_docs: list[ConnectorDocument]
    ) -> None:
        """Migrate legacy Composio documents to their native Google type.

        For each ConnectorDocument whose document_type has a Composio equivalent
        in NATIVE_TO_LEGACY_DOCTYPE, look up the old document by legacy hash and
        update its unique_identifier_hash and document_type so that
        prepare_for_indexing() can find it under the native hash.
        """
        for doc in connector_docs:
            legacy_type = NATIVE_TO_LEGACY_DOCTYPE.get(doc.document_type.value)
            if not legacy_type:
                continue

            legacy_hash = compute_identifier_hash(
                legacy_type, doc.unique_id, doc.search_space_id
            )
            result = await self.session.execute(
                select(Document).filter(
                    Document.unique_identifier_hash == legacy_hash
                )
            )
            existing = result.scalars().first()
            if existing is None:
                continue

            native_hash = compute_identifier_hash(
                doc.document_type.value, doc.unique_id, doc.search_space_id
            )
            existing.unique_identifier_hash = native_hash
            existing.document_type = doc.document_type

        await self.session.commit()

    async def index_batch(
        self, connector_docs: list[ConnectorDocument], llm
    ) -> list[Document]:
        """Convenience method: prepare_for_indexing then index each document.

        Indexers that need heartbeat callbacks or custom per-document logic
        should call prepare_for_indexing() + index() directly instead.
        """
        doc_map = {
            compute_unique_identifier_hash(cd): cd for cd in connector_docs
        }
        documents = await self.prepare_for_indexing(connector_docs)
        results: list[Document] = []
        for document in documents:
            connector_doc = doc_map.get(document.unique_identifier_hash)
            if connector_doc is None:
                continue
            result = await self.index(document, connector_doc, llm)
            results.append(result)
        return results

    async def prepare_for_indexing(
        self, connector_docs: list[ConnectorDocument]
    ) -> list[Document]:
        """
        Persist new documents and detect changes, returning only those that need indexing.
        """
        perf = get_perf_logger()
        t0 = time.perf_counter()

        documents = []
        seen_hashes: set[str] = set()
        batch_ctx = PipelineLogContext(
            connector_id=connector_docs[0].connector_id if connector_docs else 0,
            search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
            unique_id="batch",
        )

        for connector_doc in connector_docs:
            ctx = PipelineLogContext(
                connector_id=connector_doc.connector_id,
                search_space_id=connector_doc.search_space_id,
                unique_id=connector_doc.unique_id,
            )
            try:
                unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
                content_hash = compute_content_hash(connector_doc)

                if unique_identifier_hash in seen_hashes:
                    continue
                seen_hashes.add(unique_identifier_hash)

                result = await self.session.execute(
                    select(Document).filter(
                        Document.unique_identifier_hash == unique_identifier_hash
                    )
                )
                existing = result.scalars().first()

                if existing is not None:
                    if existing.content_hash == content_hash:
                        if existing.title != connector_doc.title:
                            existing.title = connector_doc.title
                            existing.updated_at = datetime.now(UTC)
                        if not DocumentStatus.is_state(
                            existing.status, DocumentStatus.READY
                        ):
                            existing.status = DocumentStatus.pending()
                            existing.updated_at = datetime.now(UTC)
                            documents.append(existing)
                            log_document_requeued(ctx)
                        continue

                    existing.title = connector_doc.title
                    existing.content_hash = content_hash
                    existing.source_markdown = connector_doc.source_markdown
                    existing.document_metadata = connector_doc.metadata
                    existing.updated_at = datetime.now(UTC)
                    existing.status = DocumentStatus.pending()
                    documents.append(existing)
                    log_document_updated(ctx)
                    continue

                duplicate = await self.session.execute(
                    select(Document).filter(Document.content_hash == content_hash)
                )
                if duplicate.scalars().first() is not None:
                    continue

                document = Document(
                    title=connector_doc.title,
                    document_type=connector_doc.document_type,
                    content="Pending...",
                    content_hash=content_hash,
                    unique_identifier_hash=unique_identifier_hash,
                    source_markdown=connector_doc.source_markdown,
                    document_metadata=connector_doc.metadata,
                    search_space_id=connector_doc.search_space_id,
                    connector_id=connector_doc.connector_id,
                    created_by_id=connector_doc.created_by_id,
                    updated_at=datetime.now(UTC),
                    status=DocumentStatus.pending(),
                )
                self.session.add(document)
                documents.append(document)
                log_document_queued(ctx)

            except Exception as e:
                log_doc_skipped_unknown(ctx, e)

        try:
            await self.session.commit()
            perf.info(
                "[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
                time.perf_counter() - t0,
                len(connector_docs),
                len(documents),
            )
            return documents
        except IntegrityError:
            log_race_condition(batch_ctx)
            await self.session.rollback()
            return []
        except Exception as e:
            log_batch_aborted(batch_ctx, e)
            await self.session.rollback()
            return []

    async def index(
        self, document: Document, connector_doc: ConnectorDocument, llm
    ) -> Document:
        """
        Run summarization, embedding, and chunking for a document and persist the results.
        """
        ctx = PipelineLogContext(
            connector_id=connector_doc.connector_id,
            search_space_id=connector_doc.search_space_id,
            unique_id=connector_doc.unique_id,
            doc_id=document.id,
        )
        perf = get_perf_logger()
        t_index = time.perf_counter()
        try:
            log_index_started(ctx)
            document.status = DocumentStatus.processing()
            await self.session.commit()

            t_step = time.perf_counter()
            if connector_doc.should_summarize and llm is not None:
                content = await summarize_document(
                    connector_doc.source_markdown, llm, connector_doc.metadata
                )
                perf.info(
                    "[indexing] summarize_document doc=%d in %.3fs",
                    document.id,
                    time.perf_counter() - t_step,
                )
            elif connector_doc.should_summarize and connector_doc.fallback_summary:
                content = connector_doc.fallback_summary
            else:
                content = connector_doc.source_markdown

            await self.session.execute(
                delete(Chunk).where(Chunk.document_id == document.id)
            )

            t_step = time.perf_counter()
            chunk_texts = await asyncio.to_thread(
                chunk_text,
                connector_doc.source_markdown,
                use_code_chunker=connector_doc.should_use_code_chunker,
            )

            texts_to_embed = [content, *chunk_texts]
            embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
            summary_embedding, *chunk_embeddings = embeddings

            chunks = [
                Chunk(content=text, embedding=emb)
                for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
            ]
            perf.info(
                "[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
                document.id,
                len(chunks),
                time.perf_counter() - t_step,
            )

            document.content = content
            document.embedding = summary_embedding
            attach_chunks_to_document(document, chunks)
            document.updated_at = datetime.now(UTC)
            document.status = DocumentStatus.ready()
            await self.session.commit()
            perf.info(
                "[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
                document.id,
                len(chunks),
                time.perf_counter() - t_index,
            )
            log_index_success(ctx, chunk_count=len(chunks))

        except RETRYABLE_LLM_ERRORS as e:
            log_retryable_llm_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, llm_retryable_message(e)
            )

        except PERMANENT_LLM_ERRORS as e:
            log_permanent_llm_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, llm_permanent_message(e)
            )

        except RecursionError as e:
            log_chunking_overflow(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, PipelineMessages.CHUNKING_OVERFLOW
            )

        except EMBEDDING_ERRORS as e:
            log_embedding_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, embedding_message(e)
            )

        except Exception as e:
            log_unexpected_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, safe_exception_message(e)
            )

        with contextlib.suppress(Exception):
            await self.session.refresh(document)

        return document

    async def index_batch_parallel(
        self,
        connector_docs: list[ConnectorDocument],
        get_llm: Callable[[AsyncSession], Awaitable],
        *,
        max_concurrency: int = 4,
        on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
        heartbeat_interval: float = 30.0,
    ) -> tuple[list[Document], int, int]:
        """Index documents in parallel with bounded concurrency.

        Phase 1 (serial): prepare_for_indexing using self.session.
        Phase 2 (parallel): index each document in an isolated session,
        bounded by a semaphore to avoid overwhelming APIs/DB.
        """
        logger = logging.getLogger(__name__)
        perf = get_perf_logger()
        t_total = time.perf_counter()

        doc_map = {
            compute_unique_identifier_hash(cd): cd for cd in connector_docs
        }
        documents = await self.prepare_for_indexing(connector_docs)

        if not documents:
            return [], 0, 0

        from app.tasks.celery_tasks import get_celery_session_maker

        sem = asyncio.Semaphore(max_concurrency)
        lock = asyncio.Lock()
        indexed_count = 0
        failed_count = 0
        results: list[Document] = []
        last_heartbeat = time.time()

        async def _index_one(document: Document) -> Document | Exception:
            nonlocal indexed_count, failed_count, last_heartbeat

            connector_doc = doc_map.get(document.unique_identifier_hash)
            if connector_doc is None:
                logger.warning(
                    "No matching ConnectorDocument for document %s, skipping",
                    document.id,
                )
                async with lock:
                    failed_count += 1
                return document

            async with sem:
                session_maker = get_celery_session_maker()
                async with session_maker() as isolated_session:
                    try:
                        refetched = await isolated_session.get(
                            Document, document.id
                        )
                        if refetched is None:
                            async with lock:
                                failed_count += 1
                            return document

                        llm = await get_llm(isolated_session)
                        iso_pipeline = IndexingPipelineService(isolated_session)
                        result = await iso_pipeline.index(
                            refetched, connector_doc, llm
                        )

                        async with lock:
                            if DocumentStatus.is_state(
                                result.status, DocumentStatus.READY
                            ):
                                indexed_count += 1
                            else:
                                failed_count += 1

                            if on_heartbeat:
                                now = time.time()
                                if now - last_heartbeat >= heartbeat_interval:
                                    await on_heartbeat(indexed_count)
                                    last_heartbeat = now

                        return result
                    except Exception as exc:
                        logger.error(
                            "Parallel index failed for doc %s: %s",
                            document.id,
                            exc,
                            exc_info=True,
                        )
                        async with lock:
                            failed_count += 1
                        return exc

        tasks = [_index_one(doc) for doc in documents]
        t_parallel = time.perf_counter()
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        perf.info(
            "[indexing] index_batch_parallel gather docs=%d concurrency=%d "
            "indexed=%d failed=%d in %.3fs",
            len(documents),
            max_concurrency,
            indexed_count,
            failed_count,
            time.perf_counter() - t_parallel,
        )

        for outcome in outcomes:
            if isinstance(outcome, Document):
                results.append(outcome)
            elif isinstance(outcome, Exception):
                pass

        perf.info(
            "[indexing] index_batch_parallel TOTAL input=%d prepared=%d "
            "indexed=%d failed=%d in %.3fs",
            len(connector_docs),
            len(documents),
            indexed_count,
            failed_count,
            time.perf_counter() - t_total,
        )
        return results, indexed_count, failed_count