diff --git a/sqlite-vec-diskann.c b/sqlite-vec-diskann.c index 1a5fd2b..7d4da6e 100644 --- a/sqlite-vec-diskann.c +++ b/sqlite-vec-diskann.c @@ -410,9 +410,18 @@ static int diskann_node_read(vec0_vtab *p, int vec_col_idx, i64 rowid, return SQLITE_NOMEM; } - memcpy(v, sqlite3_column_blob(stmt, 0), vs); - memcpy(ids, sqlite3_column_blob(stmt, 1), is); - memcpy(qv, sqlite3_column_blob(stmt, 2), qs); + const void *blobV = sqlite3_column_blob(stmt, 0); + const void *blobIds = sqlite3_column_blob(stmt, 1); + const void *blobQv = sqlite3_column_blob(stmt, 2); + if (!blobV || !blobIds || !blobQv) { + sqlite3_free(v); + sqlite3_free(ids); + sqlite3_free(qv); + return SQLITE_ERROR; + } + memcpy(v, blobV, vs); + memcpy(ids, blobIds, is); + memcpy(qv, blobQv, qs); *outValidity = v; *outValiditySize = vs; *outNeighborIds = ids; *outNeighborIdsSize = is; @@ -480,9 +489,11 @@ static int diskann_vector_read(vec0_vtab *p, int vec_col_idx, i64 rowid, } int sz = sqlite3_column_bytes(stmt, 0); + const void *blob = sqlite3_column_blob(stmt, 0); + if (!blob || sz == 0) return SQLITE_ERROR; void *vec = sqlite3_malloc(sz); if (!vec) return SQLITE_NOMEM; - memcpy(vec, sqlite3_column_blob(stmt, 0), sz); + memcpy(vec, blob, sz); *outVector = vec; *outVectorSize = sz; @@ -1325,6 +1336,7 @@ static int diskann_flush_buffer(vec0_vtab *p, int vec_col_idx) { while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { i64 rowid = sqlite3_column_int64(stmt, 0); const void *vector = sqlite3_column_blob(stmt, 1); + if (!vector) continue; // Note: vector is already written to _vectors table, so // diskann_insert_graph will skip re-writing it (vector already exists). // We call the graph-only insert path. diff --git a/sqlite-vec-rescore.c b/sqlite-vec-rescore.c index ef4e692..ef0a35c 100644 --- a/sqlite-vec-rescore.c +++ b/sqlite-vec-rescore.c @@ -426,6 +426,10 @@ static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur, unsigned char *chunkValidity = (unsigned char *)sqlite3_column_blob(stmtChunks, 1); i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2); + if (!chunkValidity || !chunkRowids) { + rc = SQLITE_ERROR; + goto cleanup; + } memset(chunk_distances, 0, p->chunk_size * sizeof(f32)); memset(chunk_topk_idxs, 0, k_oversample * sizeof(i32)); diff --git a/tests/test-diskann.py b/tests/test-diskann.py index 4fad96b..d71769c 100644 --- a/tests/test-diskann.py +++ b/tests/test-diskann.py @@ -1149,3 +1149,30 @@ def test_diskann_large_batch_insert_500(db): distances = [r[1] for r in rows] for i in range(len(distances) - 1): assert distances[i] <= distances[i + 1] + + +def test_corrupt_truncated_node_blob(db): + """KNN should error (not crash) when DiskANN node blob is truncated.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + for i in range(5): + vec = [0.0] * 8 + vec[i % 8] = 1.0 + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i + 1, _f32(vec)]) + + # Corrupt a DiskANN node: truncate neighbor_ids to 1 byte (wrong size) + db.execute( + "UPDATE t_diskann_nodes00 SET neighbor_ids = x'00' WHERE rowid = 1" + ) + + # Should not crash — may return wrong results or error + try: + db.execute( + "SELECT rowid FROM t WHERE emb MATCH ? AND k=3", + [_f32([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + except sqlite3.OperationalError: + pass # Error is acceptable — crash is not diff --git a/tests/test-rescore.py b/tests/test-rescore.py index 5025857..1dc6cd7 100644 --- a/tests/test-rescore.py +++ b/tests/test-rescore.py @@ -566,3 +566,32 @@ def test_multiple_vector_columns(db): [float_vec([1.0] * 8)], ).fetchall() assert rows[0]["rowid"] == 2 + + +def test_corrupt_zeroblob_validity(db): + """KNN should error (not crash) when rescore chunk rowids blob is zeroed out.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1, 0, 0, 0, 0, 0, 0, 0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0, 1, 0, 0, 0, 0, 0, 0])], + ) + + # Corrupt: replace rowids with a truncated blob (wrong size) + db.execute("UPDATE t_chunks SET rowids = x'00'") + + # Should not crash — may return wrong results or error + try: + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + except sqlite3.OperationalError: + pass # Error is acceptable — crash is not