mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-20 21:18:13 +02:00
refactor: make embedding cache span-aware
This commit is contained in:
parent
0ab773cbcd
commit
55491fef9d
1 changed files with 36 additions and 28 deletions
|
|
@ -18,23 +18,26 @@ from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable
|
|||
from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet
|
||||
from app.indexing_pipeline.cache.service import EmbeddingCacheService
|
||||
from app.indexing_pipeline.cache.settings import load_embedding_cache_settings
|
||||
from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid
|
||||
from app.indexing_pipeline.document_chunker import ChunkSlice, chunk_markdown_with_spans
|
||||
from app.indexing_pipeline.document_embedder import embed_texts
|
||||
from app.observability import metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ChunkPair = tuple[str, np.ndarray]
|
||||
SliceEmbedding = tuple[ChunkSlice, np.ndarray]
|
||||
|
||||
|
||||
async def build_chunk_embeddings(
|
||||
markdown: str, *, use_code_chunker: bool
|
||||
) -> tuple[np.ndarray, list[ChunkPair]]:
|
||||
"""Return the document-level vector and ordered ``(chunk_text, vector)`` pairs.
|
||||
) -> tuple[np.ndarray, list[SliceEmbedding]]:
|
||||
"""Return the document-level vector and ordered ``(ChunkSlice, vector)`` pairs.
|
||||
|
||||
Drop-in for the inline chunk+embed step; reuses prior output when the same
|
||||
markdown has already been embedded with the current model and chunker.
|
||||
Slices are always recomputed (cheap) so their char spans are exact; only the
|
||||
embeddings are cached, reused when the same markdown was embedded with the
|
||||
current model and chunker.
|
||||
"""
|
||||
slices = await chunk_slices(markdown, use_code_chunker=use_code_chunker)
|
||||
|
||||
settings = load_embedding_cache_settings()
|
||||
chunker_kind = "code" if use_code_chunker else "hybrid"
|
||||
embedding_dim = getattr(config.embedding_model_instance, "dimension", None)
|
||||
|
|
@ -45,7 +48,7 @@ async def build_chunk_embeddings(
|
|||
embedding_dim=embedding_dim,
|
||||
)
|
||||
if not cacheable:
|
||||
return await _compute(markdown, use_code_chunker=use_code_chunker)
|
||||
return await _compute(markdown, slices)
|
||||
|
||||
key = EmbeddingKey(
|
||||
markdown_sha256=_hash_text(markdown),
|
||||
|
|
@ -56,31 +59,30 @@ async def build_chunk_embeddings(
|
|||
)
|
||||
|
||||
cached = await _recall(key)
|
||||
if cached is not None:
|
||||
if cached is not None and _aligns(cached, slices):
|
||||
metrics.record_embedding_cache_lookup(
|
||||
embedding_model=key.embedding_model,
|
||||
chunker_kind=chunker_kind,
|
||||
outcome="hit",
|
||||
)
|
||||
logger.debug("Embedding cache hit for %s", key.markdown_sha256)
|
||||
return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks]
|
||||
return cached.summary_embedding, list(
|
||||
zip(slices, (c.embedding for c in cached.chunks), strict=True)
|
||||
)
|
||||
|
||||
metrics.record_embedding_cache_lookup(
|
||||
embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="miss"
|
||||
)
|
||||
summary_embedding, chunk_pairs = await _compute(
|
||||
markdown, use_code_chunker=use_code_chunker
|
||||
summary_embedding, pairs = await _compute(markdown, slices)
|
||||
await _remember(key, summary_embedding, pairs)
|
||||
return summary_embedding, pairs
|
||||
|
||||
|
||||
async def chunk_slices(markdown: str, *, use_code_chunker: bool) -> list[ChunkSlice]:
|
||||
"""Chunk markdown into ordered, char-addressed slices off the event loop."""
|
||||
return await asyncio.to_thread(
|
||||
chunk_markdown_with_spans, markdown, use_code_chunker
|
||||
)
|
||||
await _remember(key, summary_embedding, chunk_pairs)
|
||||
return summary_embedding, chunk_pairs
|
||||
|
||||
|
||||
async def chunk_markdown(markdown: str, *, use_code_chunker: bool) -> list[str]:
|
||||
"""Chunk markdown into ordered texts with the pipeline's chunker selection."""
|
||||
if use_code_chunker:
|
||||
return await asyncio.to_thread(chunk_text, markdown, use_code_chunker=True)
|
||||
# Table-aware hybrid chunker keeps Markdown tables intact (issue #1334).
|
||||
return await asyncio.to_thread(chunk_text_hybrid, markdown)
|
||||
|
||||
|
||||
async def embed_batch(texts: list[str]) -> list[np.ndarray]:
|
||||
|
|
@ -88,13 +90,19 @@ async def embed_batch(texts: list[str]) -> list[np.ndarray]:
|
|||
return await asyncio.to_thread(embed_texts, texts)
|
||||
|
||||
|
||||
def _aligns(cached: EmbeddingSet, slices: list[ChunkSlice]) -> bool:
|
||||
"""A hit is only usable if its texts still match the current chunking."""
|
||||
return len(cached.chunks) == len(slices) and all(
|
||||
c.text == s.text for c, s in zip(cached.chunks, slices, strict=True)
|
||||
)
|
||||
|
||||
|
||||
async def _compute(
|
||||
markdown: str, *, use_code_chunker: bool
|
||||
) -> tuple[np.ndarray, list[ChunkPair]]:
|
||||
chunk_texts = await chunk_markdown(markdown, use_code_chunker=use_code_chunker)
|
||||
embeddings = await embed_batch([markdown, *chunk_texts])
|
||||
markdown: str, slices: list[ChunkSlice]
|
||||
) -> tuple[np.ndarray, list[SliceEmbedding]]:
|
||||
embeddings = await embed_batch([markdown, *(s.text for s in slices)])
|
||||
summary_embedding, *chunk_embeddings = embeddings
|
||||
return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=False))
|
||||
return summary_embedding, list(zip(slices, chunk_embeddings, strict=True))
|
||||
|
||||
|
||||
async def _recall(key: EmbeddingKey) -> EmbeddingSet | None:
|
||||
|
|
@ -110,14 +118,14 @@ async def _recall(key: EmbeddingKey) -> EmbeddingSet | None:
|
|||
|
||||
|
||||
async def _remember(
|
||||
key: EmbeddingKey, summary_embedding: np.ndarray, chunk_pairs: list[ChunkPair]
|
||||
key: EmbeddingKey, summary_embedding: np.ndarray, pairs: list[SliceEmbedding]
|
||||
) -> None:
|
||||
try:
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
embedding_set = EmbeddingSet(
|
||||
summary_embedding=summary_embedding,
|
||||
chunks=[CachedChunk(text=text, embedding=vec) for text, vec in chunk_pairs],
|
||||
chunks=[CachedChunk(text=s.text, embedding=vec) for s, vec in pairs],
|
||||
)
|
||||
async with get_celery_session_maker()() as session:
|
||||
await EmbeddingCacheService(session).remember(key, embedding_set)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue