import asyncio
import contextlib
import hashlib
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime

from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import (
    NATIVE_TO_LEGACY_DOCTYPE,
    Chunk,
    Document,
    DocumentStatus,
    DocumentType,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_texts
from app.indexing_pipeline.document_hashing import (
    compute_content_hash,
    compute_identifier_hash,
    compute_unique_identifier_hash,
)
from app.indexing_pipeline.document_persistence import (
    attach_chunks_to_document,
    rollback_and_persist_failure,
)
from app.indexing_pipeline.document_summarizer import summarize_document
from app.indexing_pipeline.exceptions import (
    EMBEDDING_ERRORS,
    PERMANENT_LLM_ERRORS,
    RETRYABLE_LLM_ERRORS,
    PipelineMessages,
    embedding_message,
    llm_permanent_message,
    llm_retryable_message,
    safe_exception_message,
)
from app.indexing_pipeline.pipeline_logger import (
    PipelineLogContext,
    log_batch_aborted,
    log_chunking_overflow,
    log_doc_skipped_unknown,
    log_document_queued,
    log_document_requeued,
    log_document_updated,
    log_embedding_error,
    log_index_started,
    log_index_success,
    log_permanent_llm_error,
    log_race_condition,
    log_retryable_llm_error,
    log_unexpected_error,
)
from app.utils.perf import get_perf_logger


@dataclass
class PlaceholderInfo:
    """Minimal info to create a placeholder document row for instant UI feedback.

    These are created immediately when items are discovered (before content
    extraction) so users see them in the UI via Zero sync right away.
    """

    title: str
    document_type: DocumentType
    unique_id: str
    search_space_id: int
    connector_id: int | None
    created_by_id: str
    metadata: dict = field(default_factory=dict)


class IndexingPipelineService:
    """Single pipeline for indexing connector documents. All connectors use this service."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def create_placeholder_documents(
        self, placeholders: list[PlaceholderInfo]
    ) -> int:
        """Create placeholder document rows with pending status for instant UI feedback.

        These rows appear immediately in the UI via Zero sync. They are later
        updated by prepare_for_indexing() when actual content is available.

        Returns the number of placeholders successfully created.
        Failures are logged but never block the main indexing flow.

        NOTE: This method commits on ``self.session`` so the rows become
        visible to Zero sync immediately. Any pending ORM mutations on the
        session are committed together, which is consistent with how other
        mid-flow commits work in the indexing codebase (e.g. rename-only
        updates in ``_should_skip_file``, ``migrate_legacy_docs``).
        """
        if not placeholders:
            return 0

        _logger = logging.getLogger(__name__)

        # Dedupe within the batch: keep the first placeholder per identifier hash.
        uid_hashes: dict[str, PlaceholderInfo] = {}
        for p in placeholders:
            try:
                uid_hash = compute_identifier_hash(
                    p.document_type.value, p.unique_id, p.search_space_id
                )
                uid_hashes.setdefault(uid_hash, p)
            except Exception:
                _logger.debug(
                    "Skipping placeholder hash for %s", p.unique_id, exc_info=True
                )

        if not uid_hashes:
            return 0

        # Skip anything that already exists in the database.
        result = await self.session.execute(
            select(Document.unique_identifier_hash).where(
                Document.unique_identifier_hash.in_(list(uid_hashes.keys()))
            )
        )
        existing_hashes: set[str] = set(result.scalars().all())

        created = 0
        for uid_hash, p in uid_hashes.items():
            if uid_hash in existing_hashes:
                continue
            try:
                # Synthetic content hash; the real one is set once content arrives.
                content_hash = hashlib.sha256(
                    f"placeholder:{uid_hash}".encode()
                ).hexdigest()

                document = Document(
                    title=p.title,
                    document_type=p.document_type,
                    content="Pending...",
                    content_hash=content_hash,
                    unique_identifier_hash=uid_hash,
                    document_metadata=p.metadata or {},
                    search_space_id=p.search_space_id,
                    connector_id=p.connector_id,
                    created_by_id=p.created_by_id,
                    updated_at=datetime.now(UTC),
                    status=DocumentStatus.pending(),
                )
                self.session.add(document)
                created += 1
            except Exception:
                _logger.debug("Skipping placeholder for %s", p.unique_id, exc_info=True)

        if created > 0:
            try:
                await self.session.commit()
                _logger.info(
                    "Created %d placeholder document(s) for instant UI feedback",
                    created,
                )
            except IntegrityError:
                # Another worker inserted the same rows first; that copy wins.
                await self.session.rollback()
                _logger.debug("Placeholder commit failed (race condition), continuing")
                created = 0

        return created
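
    # Illustrative flow (a sketch with hypothetical values; ``DocumentType.FILE``
    # and ``session`` are assumptions for illustration, not names from this file):
    #
    #     pipeline = IndexingPipelineService(session)
    #     created = await pipeline.create_placeholder_documents([
    #         PlaceholderInfo(
    #             title="Q3 report.pdf",
    #             document_type=DocumentType.FILE,
    #             unique_id="drive-file-123",
    #             search_space_id=1,
    #             connector_id=42,
    #             created_by_id="user-abc",
    #         )
    #     ])
    #     # created == 1 on first discovery, 0 if the row already exists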

    async def migrate_legacy_docs(
        self, connector_docs: list[ConnectorDocument]
    ) -> None:
        """Migrate legacy Composio documents to their native Google type.

        For each ConnectorDocument whose document_type has a Composio equivalent
        in NATIVE_TO_LEGACY_DOCTYPE, look up the old document by legacy hash and
        update its unique_identifier_hash and document_type so that
        prepare_for_indexing() can find it under the native hash.
        """
        for doc in connector_docs:
            legacy_type = NATIVE_TO_LEGACY_DOCTYPE.get(doc.document_type.value)
            if not legacy_type:
                continue

            legacy_hash = compute_identifier_hash(
                legacy_type, doc.unique_id, doc.search_space_id
            )
            result = await self.session.execute(
                select(Document).filter(Document.unique_identifier_hash == legacy_hash)
            )
            existing = result.scalars().first()
            if existing is None:
                continue

            # Re-key the row under the native hash so later lookups match.
            native_hash = compute_identifier_hash(
                doc.document_type.value, doc.unique_id, doc.search_space_id
            )
            existing.unique_identifier_hash = native_hash
            existing.document_type = doc.document_type

        await self.session.commit()
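
    # NATIVE_TO_LEGACY_DOCTYPE (imported from app.db) maps native document-type
    # values to their legacy Composio equivalents. Its exact members live in
    # app.db; a purely illustrative sketch of the expected shape:
    #
    #     NATIVE_TO_LEGACY_DOCTYPE = {
    #         "GOOGLE_GMAIL": "COMPOSIO_GMAIL",   # hypothetical keys/values
    #         "GOOGLE_DRIVE": "COMPOSIO_DRIVE",
    #     }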

    async def index_batch(
        self, connector_docs: list[ConnectorDocument], llm
    ) -> list[Document]:
        """Convenience method: prepare_for_indexing then index each document.

        Indexers that need heartbeat callbacks or custom per-document logic
        should call prepare_for_indexing() + index() directly instead.
        """
        doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
        documents = await self.prepare_for_indexing(connector_docs)
        results: list[Document] = []
        for document in documents:
            connector_doc = doc_map.get(document.unique_identifier_hash)
            if connector_doc is None:
                continue
            result = await self.index(document, connector_doc, llm)
            results.append(result)
        return results
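
    # Typical call site (sketch; ``session``, ``connector_docs`` and ``llm`` are
    # assumed to come from the surrounding task, not from this file):
    #
    #     pipeline = IndexingPipelineService(session)
    #     indexed = await pipeline.index_batch(connector_docs, llm)
    #     ready = [
    #         d for d in indexed
    #         if DocumentStatus.is_state(d.status, DocumentStatus.READY)
    #     ]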

    async def prepare_for_indexing(
        self, connector_docs: list[ConnectorDocument]
    ) -> list[Document]:
        """
        Persist new documents and detect changes, returning only those that need indexing.
        """
        perf = get_perf_logger()
        t0 = time.perf_counter()

        documents = []
        seen_hashes: set[str] = set()
        batch_ctx = PipelineLogContext(
            connector_id=connector_docs[0].connector_id if connector_docs else 0,
            search_space_id=connector_docs[0].search_space_id if connector_docs else 0,
            unique_id="batch",
        )

        for connector_doc in connector_docs:
            ctx = PipelineLogContext(
                connector_id=connector_doc.connector_id,
                search_space_id=connector_doc.search_space_id,
                unique_id=connector_doc.unique_id,
            )
            try:
                unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
                content_hash = compute_content_hash(connector_doc)

                # Skip duplicates within this batch.
                if unique_identifier_hash in seen_hashes:
                    continue
                seen_hashes.add(unique_identifier_hash)

                result = await self.session.execute(
                    select(Document).filter(
                        Document.unique_identifier_hash == unique_identifier_hash
                    )
                )
                existing = result.scalars().first()

                if existing is not None:
                    if existing.content_hash == content_hash:
                        # Unchanged content: pick up renames, and requeue only
                        # if the document is not already READY.
                        if existing.title != connector_doc.title:
                            existing.title = connector_doc.title
                            existing.updated_at = datetime.now(UTC)
                        if not DocumentStatus.is_state(
                            existing.status, DocumentStatus.READY
                        ):
                            existing.status = DocumentStatus.pending()
                            existing.updated_at = datetime.now(UTC)
                            if connector_doc.folder_id is not None:
                                existing.folder_id = connector_doc.folder_id
                            documents.append(existing)
                            log_document_requeued(ctx)
                        continue

                    # Changed content: refuse the update if another document
                    # already holds the same content hash.
                    dup_check = await self.session.execute(
                        select(Document.id, Document.title).filter(
                            Document.content_hash == content_hash,
                            Document.id != existing.id,
                        )
                    )
                    dup_row = dup_check.first()
                    if dup_row is not None:
                        if not DocumentStatus.is_state(
                            existing.status, DocumentStatus.READY
                        ):
                            existing.status = DocumentStatus.failed(
                                f"Duplicate content: matches '{dup_row.title}'"
                            )
                        continue

                    existing.title = connector_doc.title
                    existing.content_hash = content_hash
                    existing.source_markdown = connector_doc.source_markdown
                    existing.document_metadata = connector_doc.metadata
                    existing.updated_at = datetime.now(UTC)
                    existing.status = DocumentStatus.pending()
                    if connector_doc.folder_id is not None:
                        existing.folder_id = connector_doc.folder_id
                    documents.append(existing)
                    log_document_updated(ctx)
                    continue

                # New document: skip if its content already exists elsewhere.
                duplicate = await self.session.execute(
                    select(Document).filter(Document.content_hash == content_hash)
                )
                if duplicate.scalars().first() is not None:
                    continue

                document = Document(
                    title=connector_doc.title,
                    document_type=connector_doc.document_type,
                    content="Pending...",
                    content_hash=content_hash,
                    unique_identifier_hash=unique_identifier_hash,
                    source_markdown=connector_doc.source_markdown,
                    document_metadata=connector_doc.metadata,
                    search_space_id=connector_doc.search_space_id,
                    connector_id=connector_doc.connector_id,
                    created_by_id=connector_doc.created_by_id,
                    updated_at=datetime.now(UTC),
                    status=DocumentStatus.pending(),
                    folder_id=connector_doc.folder_id,
                )
                self.session.add(document)
                documents.append(document)
                log_document_queued(ctx)

            except Exception as e:
                log_doc_skipped_unknown(ctx, e)

        try:
            await self.session.commit()
            perf.info(
                "[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
                time.perf_counter() - t0,
                len(connector_docs),
                len(documents),
            )
            return documents
        except IntegrityError:
            log_race_condition(batch_ctx)
            await self.session.rollback()
            return []
        except Exception as e:
            log_batch_aborted(batch_ctx, e)
            await self.session.rollback()
            return []
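
    # Outcome summary for prepare_for_indexing (restating the branches above):
    #
    #     existing? | content changed?    | result
    #     ----------+---------------------+---------------------------------------
    #     yes       | no                  | rename if needed; requeue unless READY
    #     yes       | yes, dup hash found | mark failed as duplicate (unless READY)
    #     yes       | yes                 | update in place, status -> pending
    #     no        | hash already taken  | skip silently
    #     no        | genuinely new       | insert with status -> pending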

    async def index(
        self, document: Document, connector_doc: ConnectorDocument, llm
    ) -> Document:
        """
        Run summarization, embedding, and chunking for a document and persist the results.
        """
        ctx = PipelineLogContext(
            connector_id=connector_doc.connector_id,
            search_space_id=connector_doc.search_space_id,
            unique_id=connector_doc.unique_id,
            doc_id=document.id,
        )
        perf = get_perf_logger()
        t_index = time.perf_counter()
        try:
            log_index_started(ctx)
            document.status = DocumentStatus.processing()
            await self.session.commit()

            # Pick the document-level content: LLM summary, fallback summary,
            # or the raw source markdown itself.
            t_step = time.perf_counter()
            if connector_doc.should_summarize and llm is not None:
                content = await summarize_document(
                    connector_doc.source_markdown, llm, connector_doc.metadata
                )
                perf.info(
                    "[indexing] summarize_document doc=%d in %.3fs",
                    document.id,
                    time.perf_counter() - t_step,
                )
            elif connector_doc.should_summarize and connector_doc.fallback_summary:
                content = connector_doc.fallback_summary
            else:
                content = connector_doc.source_markdown

            # Drop stale chunks before re-chunking (re-index path).
            await self.session.execute(
                delete(Chunk).where(Chunk.document_id == document.id)
            )

            t_step = time.perf_counter()
            chunk_texts = await asyncio.to_thread(
                chunk_text,
                connector_doc.source_markdown,
                use_code_chunker=connector_doc.should_use_code_chunker,
            )

            # One batched embedding call: index 0 is the document-level text,
            # the rest line up 1:1 with chunk_texts.
            texts_to_embed = [content, *chunk_texts]
            embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
            summary_embedding, *chunk_embeddings = embeddings

            chunks = [
                Chunk(content=text, embedding=emb)
                for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
            ]
            perf.info(
                "[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
                document.id,
                len(chunks),
                time.perf_counter() - t_step,
            )

            document.content = content
            document.embedding = summary_embedding
            attach_chunks_to_document(document, chunks)
            document.updated_at = datetime.now(UTC)
            document.status = DocumentStatus.ready()
            await self.session.commit()
            perf.info(
                "[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
                document.id,
                len(chunks),
                time.perf_counter() - t_index,
            )
            log_index_success(ctx, chunk_count=len(chunks))

            await self._enqueue_ai_sort_if_enabled(document)

        except RETRYABLE_LLM_ERRORS as e:
            log_retryable_llm_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, llm_retryable_message(e)
            )

        except PERMANENT_LLM_ERRORS as e:
            log_permanent_llm_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, llm_permanent_message(e)
            )

        except RecursionError as e:
            log_chunking_overflow(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, PipelineMessages.CHUNKING_OVERFLOW
            )

        except EMBEDDING_ERRORS as e:
            log_embedding_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, embedding_message(e)
            )

        except Exception as e:
            log_unexpected_error(ctx, e)
            await rollback_and_persist_failure(
                self.session, document, safe_exception_message(e)
            )

        with contextlib.suppress(Exception):
            await self.session.refresh(document)

        return document
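
    # Design note on the embed step above: the summary and all chunks go out in
    # a single batch, split positionally on return. ``strict=False`` in the zip
    # means a short embeddings list silently drops trailing chunks rather than
    # raising, so a mismatched embedder response degrades to fewer chunks
    # instead of failing the whole document.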

    async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None:
        """Fire-and-forget: enqueue incremental AI sort if the search space has it enabled."""
        try:
            from app.db import SearchSpace

            result = await self.session.execute(
                select(SearchSpace.ai_file_sort_enabled).where(
                    SearchSpace.id == document.search_space_id
                )
            )
            enabled = result.scalar()
            if not enabled:
                return

            # Imported lazily (presumably to avoid import cycles with the
            # Celery task module).
            from app.tasks.celery_tasks.document_tasks import ai_sort_document_task

            user_id = str(document.created_by_id) if document.created_by_id else ""
            ai_sort_document_task.delay(document.search_space_id, user_id, document.id)
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to enqueue AI sort for document %s", document.id, exc_info=True
            )

    async def index_batch_parallel(
        self,
        connector_docs: list[ConnectorDocument],
        get_llm: Callable[[AsyncSession], Awaitable],
        *,
        max_concurrency: int = 4,
        on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
        heartbeat_interval: float = 30.0,
    ) -> tuple[list[Document], int, int]:
        """Index documents in parallel with bounded concurrency.

        Phase 1 (serial): prepare_for_indexing using self.session.
        Phase 2 (parallel): index each document in an isolated session,
        bounded by a semaphore to avoid overwhelming APIs/DB.
        """
        logger = logging.getLogger(__name__)
        perf = get_perf_logger()
        t_total = time.perf_counter()

        doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
        documents = await self.prepare_for_indexing(connector_docs)

        if not documents:
            return [], 0, 0

        from app.tasks.celery_tasks import get_celery_session_maker

        sem = asyncio.Semaphore(max_concurrency)
        lock = asyncio.Lock()
        indexed_count = 0
        failed_count = 0
        results: list[Document] = []
        last_heartbeat = time.time()

        async def _index_one(document: Document) -> Document | Exception:
            nonlocal indexed_count, failed_count, last_heartbeat

            connector_doc = doc_map.get(document.unique_identifier_hash)
            if connector_doc is None:
                logger.warning(
                    "No matching ConnectorDocument for document %s, skipping",
                    document.id,
                )
                async with lock:
                    failed_count += 1
                return document

            async with sem:
                session_maker = get_celery_session_maker()
                async with session_maker() as isolated_session:
                    try:
                        # Re-fetch in the isolated session; the ORM instance
                        # from self.session must not cross session boundaries.
                        refetched = await isolated_session.get(Document, document.id)
                        if refetched is None:
                            async with lock:
                                failed_count += 1
                            return document

                        llm = await get_llm(isolated_session)
                        iso_pipeline = IndexingPipelineService(isolated_session)
                        result = await iso_pipeline.index(refetched, connector_doc, llm)

                        async with lock:
                            if DocumentStatus.is_state(
                                result.status, DocumentStatus.READY
                            ):
                                indexed_count += 1
                            else:
                                failed_count += 1

                            if on_heartbeat:
                                now = time.time()
                                if now - last_heartbeat >= heartbeat_interval:
                                    await on_heartbeat(indexed_count)
                                    last_heartbeat = now

                        return result
                    except Exception as exc:
                        logger.error(
                            "Parallel index failed for doc %s: %s",
                            document.id,
                            exc,
                            exc_info=True,
                        )
                        async with lock:
                            failed_count += 1
                        return exc

        tasks = [_index_one(doc) for doc in documents]
        t_parallel = time.perf_counter()
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        perf.info(
            "[indexing] index_batch_parallel gather docs=%d concurrency=%d "
            "indexed=%d failed=%d in %.3fs",
            len(documents),
            max_concurrency,
            indexed_count,
            failed_count,
            time.perf_counter() - t_parallel,
        )

        # Exceptions were already counted as failures inside _index_one.
        for outcome in outcomes:
            if isinstance(outcome, Document):
                results.append(outcome)

        perf.info(
            "[indexing] index_batch_parallel TOTAL input=%d prepared=%d "
            "indexed=%d failed=%d in %.3fs",
            len(connector_docs),
            len(documents),
            indexed_count,
            failed_count,
            time.perf_counter() - t_total,
        )
        return results, indexed_count, failed_count
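
    # Typical call site (sketch; ``session``, ``connector_docs`` and ``build_llm``
    # are assumptions for illustration, not names from this file):
    #
    #     async def build_llm(s: AsyncSession):
    #         ...  # resolve the user's LLM client from the isolated session
    #
    #     async def heartbeat(indexed: int) -> None:
    #         ...  # e.g. refresh a task lock / progress counter
    #
    #     pipeline = IndexingPipelineService(session)
    #     docs, indexed, failed = await pipeline.index_batch_parallel(
    #         connector_docs,
    #         build_llm,
    #         max_concurrency=4,
    #         on_heartbeat=heartbeat,
    #         heartbeat_interval=30.0,
    #     )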