feat(index-cache): add pickle-free blob serialization

This commit is contained in:
CREDO23 2026-06-12 16:48:01 +02:00
parent cf208365b4
commit 59fa4c38c3

View file

@ -0,0 +1,71 @@
"""Serialize an EmbeddingSet to a compact, self-describing blob (no pickle).
Layout: ``MAGIC | uint32 header_len | json header | float32 matrix``. The header
carries the dim, chunk count, and ordered chunk texts; the matrix holds the
summary vector followed by one row per chunk, all float32 for compactness.
"""
from __future__ import annotations
import json
import struct
import numpy as np
from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet
# Marker at the start of every blob: "SurfSense EMBeddings, version 1"-> SSEMB1. Lets us
# reject foreign blobs and bump the trailing digit if the layout ever changes.
_MAGIC = b"SSEMB1"
# 4-byte big-endian unsigned int written before the variable-length JSON header,
# so the reader knows where the header ends and the float matrix begins.
_HEADER_LEN = struct.Struct(">I")
def serialize(embedding_set: EmbeddingSet) -> bytes:
summary = np.asarray(embedding_set.summary_embedding, dtype=np.float32).reshape(-1)
dim = int(summary.shape[0])
rows = [summary]
texts: list[str] = []
for chunk in embedding_set.chunks:
vector = np.asarray(chunk.embedding, dtype=np.float32).reshape(-1)
if vector.shape[0] != dim:
raise ValueError("All vectors in an embedding set must share one dimension.")
rows.append(vector)
texts.append(chunk.text)
matrix = np.stack(rows, axis=0)
header = json.dumps(
{"dim": dim, "count": len(texts), "texts": texts}, ensure_ascii=False
).encode("utf-8")
return b"".join(
[_MAGIC, _HEADER_LEN.pack(len(header)), header, matrix.tobytes(order="C")]
)
def deserialize(blob: bytes) -> EmbeddingSet:
view = memoryview(blob)
if bytes(view[: len(_MAGIC)]) != _MAGIC:
raise ValueError("Unrecognized embedding cache blob.")
offset = len(_MAGIC)
(header_len,) = _HEADER_LEN.unpack(view[offset : offset + _HEADER_LEN.size])
offset += _HEADER_LEN.size
header = json.loads(bytes(view[offset : offset + header_len]).decode("utf-8"))
offset += header_len
dim = int(header["dim"])
count = int(header["count"])
texts: list[str] = header["texts"]
matrix = np.frombuffer(view[offset:], dtype=np.float32)
if matrix.shape[0] != (count + 1) * dim:
raise ValueError("Embedding cache blob is truncated or corrupt.")
matrix = matrix.reshape(count + 1, dim)
return EmbeddingSet(
summary_embedding=matrix[0],
chunks=[CachedChunk(text=texts[i], embedding=matrix[i + 1]) for i in range(count)],
)