diff --git a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py index 95321a229..58872a219 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py +++ b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py @@ -18,23 +18,26 @@ from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet from app.indexing_pipeline.cache.service import EmbeddingCacheService from app.indexing_pipeline.cache.settings import load_embedding_cache_settings -from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid +from app.indexing_pipeline.document_chunker import ChunkSlice, chunk_markdown_with_spans from app.indexing_pipeline.document_embedder import embed_texts from app.observability import metrics logger = logging.getLogger(__name__) -ChunkPair = tuple[str, np.ndarray] +SliceEmbedding = tuple[ChunkSlice, np.ndarray] async def build_chunk_embeddings( markdown: str, *, use_code_chunker: bool -) -> tuple[np.ndarray, list[ChunkPair]]: - """Return the document-level vector and ordered ``(chunk_text, vector)`` pairs. +) -> tuple[np.ndarray, list[SliceEmbedding]]: + """Return the document-level vector and ordered ``(ChunkSlice, vector)`` pairs. - Drop-in for the inline chunk+embed step; reuses prior output when the same - markdown has already been embedded with the current model and chunker. + Slices are always recomputed (cheap) so their char spans are exact; only the + embeddings are cached, reused when the same markdown was embedded with the + current model and chunker. """ + slices = await chunk_slices(markdown, use_code_chunker=use_code_chunker) + settings = load_embedding_cache_settings() chunker_kind = "code" if use_code_chunker else "hybrid" embedding_dim = getattr(config.embedding_model_instance, "dimension", None) @@ -45,7 +48,7 @@ async def build_chunk_embeddings( embedding_dim=embedding_dim, ) if not cacheable: - return await _compute(markdown, use_code_chunker=use_code_chunker) + return await _compute(markdown, slices) key = EmbeddingKey( markdown_sha256=_hash_text(markdown), @@ -56,31 +59,30 @@ async def build_chunk_embeddings( ) cached = await _recall(key) - if cached is not None: + if cached is not None and _aligns(cached, slices): metrics.record_embedding_cache_lookup( embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="hit", ) logger.debug("Embedding cache hit for %s", key.markdown_sha256) - return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks] + return cached.summary_embedding, list( + zip(slices, (c.embedding for c in cached.chunks), strict=True) + ) metrics.record_embedding_cache_lookup( embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="miss" ) - summary_embedding, chunk_pairs = await _compute( - markdown, use_code_chunker=use_code_chunker + summary_embedding, pairs = await _compute(markdown, slices) + await _remember(key, summary_embedding, pairs) + return summary_embedding, pairs + + +async def chunk_slices(markdown: str, *, use_code_chunker: bool) -> list[ChunkSlice]: + """Chunk markdown into ordered, char-addressed slices off the event loop.""" + return await asyncio.to_thread( + chunk_markdown_with_spans, markdown, use_code_chunker ) - await _remember(key, summary_embedding, chunk_pairs) - return summary_embedding, chunk_pairs - - -async def chunk_markdown(markdown: str, *, use_code_chunker: bool) -> list[str]: - """Chunk markdown into ordered texts with the pipeline's chunker selection.""" - if use_code_chunker: - return await asyncio.to_thread(chunk_text, markdown, use_code_chunker=True) - # Table-aware hybrid chunker keeps Markdown tables intact (issue #1334). - return await asyncio.to_thread(chunk_text_hybrid, markdown) async def embed_batch(texts: list[str]) -> list[np.ndarray]: @@ -88,13 +90,19 @@ async def embed_batch(texts: list[str]) -> list[np.ndarray]: return await asyncio.to_thread(embed_texts, texts) +def _aligns(cached: EmbeddingSet, slices: list[ChunkSlice]) -> bool: + """A hit is only usable if its texts still match the current chunking.""" + return len(cached.chunks) == len(slices) and all( + c.text == s.text for c, s in zip(cached.chunks, slices, strict=True) + ) + + async def _compute( - markdown: str, *, use_code_chunker: bool -) -> tuple[np.ndarray, list[ChunkPair]]: - chunk_texts = await chunk_markdown(markdown, use_code_chunker=use_code_chunker) - embeddings = await embed_batch([markdown, *chunk_texts]) + markdown: str, slices: list[ChunkSlice] +) -> tuple[np.ndarray, list[SliceEmbedding]]: + embeddings = await embed_batch([markdown, *(s.text for s in slices)]) summary_embedding, *chunk_embeddings = embeddings - return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=False)) + return summary_embedding, list(zip(slices, chunk_embeddings, strict=True)) async def _recall(key: EmbeddingKey) -> EmbeddingSet | None: @@ -110,14 +118,14 @@ async def _recall(key: EmbeddingKey) -> EmbeddingSet | None: async def _remember( - key: EmbeddingKey, summary_embedding: np.ndarray, chunk_pairs: list[ChunkPair] + key: EmbeddingKey, summary_embedding: np.ndarray, pairs: list[SliceEmbedding] ) -> None: try: from app.tasks.celery_tasks import get_celery_session_maker embedding_set = EmbeddingSet( summary_embedding=summary_embedding, - chunks=[CachedChunk(text=text, embedding=vec) for text, vec in chunk_pairs], + chunks=[CachedChunk(text=s.text, embedding=vec) for s, vec in pairs], ) async with get_celery_session_maker()() as session: await EmbeddingCacheService(session).remember(key, embedding_set)