Add DiskANN index for vec0 virtual table

Add DiskANN graph-based index: builds a Vamana graph with configurable R
(max degree) and L (search list size, separate for insert/query), supports
int8 quantization with rescore, lazy reverse-edge replacement, pre-quantized
query optimization, and insert buffer reuse. Includes shadow table management,
delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_DISKANN),
release-demo workflow, fuzz targets, and tests. Fixes rescore int8
quantization bug.
This commit is contained in:
Alex Garcia 2026-03-29 19:46:53 -07:00
parent e2c38f387c
commit 575371d751
23 changed files with 6550 additions and 135 deletions

View file

@ -1187,6 +1187,7 @@ void test_ivf_quantize_binary() {
}
void test_ivf_config_parsing() {
void test_vec0_parse_vector_column_diskann() {
printf("Starting %s...\n", __func__);
struct VectorColumnDefinition col;
int rc;
@ -1199,6 +1200,34 @@ void test_ivf_config_parsing() {
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT);
assert(col.rescore.oversample == 8); // default
// Existing syntax (no INDEXED BY) should not select the DiskANN index type
{
const char *input = "emb float[128]";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
sqlite3_free(col.name);
}
// With distance_metric but no INDEXED BY
{
const char *input = "emb float[128] distance_metric=cosine";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
sqlite3_free(col.name);
}
// Basic binary quantizer
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY);
assert(col.diskann.n_neighbors == 72); // default
assert(col.diskann.search_list_size == 128); // default
assert(col.dimensions == 128);
sqlite3_free(col.name);
}
@ -1370,6 +1399,681 @@ void test_ivf_config_parsing() {
printf(" All ivf_config_parsing tests passed.\n");
}
#endif /* SQLITE_VEC_ENABLE_IVF */
// INT8 quantizer
{
const char *input = "v float[64] INDEXED BY diskann(neighbor_quantizer=int8)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8);
sqlite3_free(col.name);
}
// Custom n_neighbors
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=48)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
assert(col.diskann.n_neighbors == 48);
sqlite3_free(col.name);
}
// Custom search_list_size
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=256)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.diskann.search_list_size == 256);
sqlite3_free(col.name);
}
// Combined with distance_metric (distance_metric first)
{
const char *input = "emb float[128] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=int8)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8);
sqlite3_free(col.name);
}
// Error: missing neighbor_quantizer (required)
{
const char *input = "emb float[128] INDEXED BY diskann(n_neighbors=72)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: empty parens
{
const char *input = "emb float[128] INDEXED BY diskann()";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: unknown quantizer
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=unknown)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: bad n_neighbors (not divisible by 8)
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=13)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: n_neighbors too large
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=512)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: missing BY
{
const char *input = "emb float[128] INDEXED diskann(neighbor_quantizer=binary)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: unknown algorithm
{
const char *input = "emb float[128] INDEXED BY hnsw(neighbor_quantizer=binary)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: unknown option key
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, foobar=baz)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Case insensitivity for keywords
{
const char *input = "emb float[128] indexed by DISKANN(NEIGHBOR_QUANTIZER=BINARY)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY);
sqlite3_free(col.name);
}
// Split search_list_size: search and insert
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=256, search_list_size_insert=64)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.diskann.search_list_size == 128); // default (unified)
assert(col.diskann.search_list_size_search == 256);
assert(col.diskann.search_list_size_insert == 64);
sqlite3_free(col.name);
}
// Split search_list_size: only search
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=200)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_OK);
assert(col.diskann.search_list_size_search == 200);
assert(col.diskann.search_list_size_insert == 0);
sqlite3_free(col.name);
}
// Error: cannot mix search_list_size with search_list_size_search
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_search=256)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
// Error: cannot mix search_list_size with search_list_size_insert
{
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_insert=64)";
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == SQLITE_ERROR);
}
printf(" All vec0_parse_vector_column_diskann tests passed.\n");
}
// Exercises the validity-bitmap helpers (get / set / count) on a 24-bit
// bitmap, probing bits that sit on either side of a byte boundary.
void test_diskann_validity_bitmap() {
  printf("Starting %s...\n", __func__);

  enum { N_BITS = 24 };
  unsigned char bitmap[N_BITS / 8];
  memset(bitmap, 0, sizeof(bitmap));

  // A zeroed bitmap must report every slot as invalid.
  int bit = 0;
  while (bit < N_BITS) {
    assert(diskann_validity_get(bitmap, bit) == 0);
    bit++;
  }
  assert(diskann_validity_count(bitmap, N_BITS) == 0);

  // Switch on bit 0, bit 7 (last of byte 0), bit 8 (first of byte 1) and
  // bit 23 (last bit overall). The count must grow by one per step.
  const int probes[] = {0, 7, 8, 23};
  const int n_probes = (int)(sizeof(probes) / sizeof(probes[0]));
  for (int k = 0; k < n_probes; k++) {
    diskann_validity_set(bitmap, probes[k], 1);
    assert(diskann_validity_get(bitmap, probes[k]) == 1);
    assert(diskann_validity_count(bitmap, N_BITS) == k + 1);
  }

  // Clearing bit 0 lowers the count and must not disturb bits 7 and 8.
  diskann_validity_set(bitmap, 0, 0);
  assert(diskann_validity_get(bitmap, 0) == 0);
  assert(diskann_validity_count(bitmap, N_BITS) == 3);
  assert(diskann_validity_get(bitmap, 7) == 1);
  assert(diskann_validity_get(bitmap, 8) == 1);

  printf(" All diskann_validity_bitmap tests passed.\n");
}
// Round-trips rowids through the neighbor-id buffer accessors and checks
// that writing one slot leaves a previously written slot intact.
void test_diskann_neighbor_ids() {
  printf("Starting %s...\n", __func__);

  enum { N_SLOTS = 8, BYTES_PER_ID = 8 };
  unsigned char id_buf[N_SLOTS * BYTES_PER_ID];
  memset(id_buf, 0, sizeof(id_buf));

  // First, middle and last slot: each value must read back right after
  // it is written.
  const struct {
    int slot;
    int64_t rowid;
  } writes[] = {
      {0, 42},
      {3, 12345},
      {7, 99999},
  };
  const int n_writes = (int)(sizeof(writes) / sizeof(writes[0]));
  for (int k = 0; k < n_writes; k++) {
    diskann_neighbor_id_set(id_buf, writes[k].slot, writes[k].rowid);
    assert(diskann_neighbor_id_get(id_buf, writes[k].slot) == writes[k].rowid);
  }

  // The later writes must not have clobbered slot 0.
  assert(diskann_neighbor_id_get(id_buf, 0) == 42);

  // Largest positive 64-bit value survives the round trip.
  diskann_neighbor_id_set(id_buf, 1, INT64_MAX);
  assert(diskann_neighbor_id_get(id_buf, 1) == INT64_MAX);

  printf(" All diskann_neighbor_ids tests passed.\n");
}
// Binary quantization of an 8-dimensional vector into a single byte.
// The test expects a strictly positive component to map to bit 1 and
// everything else (negative or exactly zero) to bit 0, packed LSB first.
void test_diskann_quantize_binary() {
  printf("Starting %s...\n", __func__);

  //  dim: 0     1      2     3     4      5     6      7
  //  bit: 1     0      1     0     0      1     0      1
  float input[8] = {1.0f, -1.0f, 0.5f, 0.0f, -0.5f, 0.1f, -0.1f, 100.0f};
  unsigned char packed[1]; // 8 dims -> 8 bits -> 1 byte

  // Bits 0, 2, 5 and 7 set: 1 + 4 + 32 + 128 = 0xA5.
  const unsigned char expected =
      (unsigned char)((1u << 0) | (1u << 2) | (1u << 5) | (1u << 7));

  int rc = diskann_quantize_vector(input, 8, VEC0_DISKANN_QUANTIZER_BINARY,
                                   packed);
  assert(rc == 0);
  assert(packed[0] == expected);

  printf(" All diskann_quantize_binary tests passed.\n");
}
// Verifies the three buffer sizes reported by diskann_node_init() for a
// binary-quantized and an int8-quantized node configuration.
void test_diskann_node_init_sizes() {
  printf("Starting %s...\n", __func__);

  unsigned char *valid_bits;
  unsigned char *neighbor_ids;
  unsigned char *neighbor_qvecs;
  int valid_bytes;
  int id_bytes;
  int qvec_bytes;
  int rc;

  // Case 1: 72 neighbors, binary quantizer, 1024 dimensions.
  rc = diskann_node_init(72, VEC0_DISKANN_QUANTIZER_BINARY, 1024,
                         &valid_bits, &valid_bytes,
                         &neighbor_ids, &id_bytes,
                         &neighbor_qvecs, &qvec_bytes);
  assert(rc == 0);
  assert(valid_bytes == 72 / 8);          // one bit per neighbor -> 9
  assert(id_bytes == 72 * 8);             // 8 bytes per neighbor id -> 576
  assert(qvec_bytes == 72 * (1024 / 8));  // one bit per dimension -> 9216
  // A fresh node has no valid neighbor slots.
  assert(diskann_validity_count(valid_bits, 72) == 0);
  sqlite3_free(valid_bits);
  sqlite3_free(neighbor_ids);
  sqlite3_free(neighbor_qvecs);

  // Case 2: 8 neighbors, int8 quantizer, 32 dimensions.
  rc = diskann_node_init(8, VEC0_DISKANN_QUANTIZER_INT8, 32,
                         &valid_bits, &valid_bytes,
                         &neighbor_ids, &id_bytes,
                         &neighbor_qvecs, &qvec_bytes);
  assert(rc == 0);
  assert(valid_bytes == 8 / 8);   // 1
  assert(id_bytes == 8 * 8);      // 64
  assert(qvec_bytes == 8 * 32);   // one byte per dimension -> 256
  sqlite3_free(valid_bits);
  sqlite3_free(neighbor_ids);
  sqlite3_free(neighbor_qvecs);

  printf(" All diskann_node_init_sizes tests passed.\n");
}
// Sets a neighbor in one slot of a freshly initialised node, reads back its
// validity bit, rowid and quantized vector, then clears the slot again.
void test_diskann_node_set_clear_neighbor() {
  printf("Starting %s...\n", __func__);

  enum { N_NEIGHBORS = 8, DIMS = 16, SLOT = 3 };

  unsigned char *valid_bits;
  unsigned char *neighbor_ids;
  unsigned char *neighbor_qvecs;
  int valid_bytes;
  int id_bytes;
  int qvec_bytes;

  // Binary quantizer at 16 dims -> 2 bytes per stored neighbor vector.
  int rc = diskann_node_init(N_NEIGHBORS, VEC0_DISKANN_QUANTIZER_BINARY, DIMS,
                             &valid_bits, &valid_bytes,
                             &neighbor_ids, &id_bytes,
                             &neighbor_qvecs, &qvec_bytes);
  assert(rc == 0);

  // Populate the slot with rowid 42 and a recognisable 2-byte pattern.
  unsigned char pattern[2] = {0xAB, 0xCD};
  diskann_node_set_neighbor(valid_bits, neighbor_ids, neighbor_qvecs, SLOT,
                            42, pattern, VEC0_DISKANN_QUANTIZER_BINARY, DIMS);

  // The slot is now the only valid one and carries the rowid...
  assert(diskann_validity_get(valid_bits, SLOT) == 1);
  assert(diskann_validity_count(valid_bits, N_NEIGHBORS) == 1);
  assert(diskann_neighbor_id_get(neighbor_ids, SLOT) == 42);

  // ...and the quantized bytes that were written.
  const unsigned char *stored = diskann_neighbor_qvec_get(
      neighbor_qvecs, SLOT, VEC0_DISKANN_QUANTIZER_BINARY, DIMS);
  assert(stored[0] == 0xAB);
  assert(stored[1] == 0xCD);

  // Clearing resets the validity bit and the stored rowid.
  diskann_node_clear_neighbor(valid_bits, neighbor_ids, neighbor_qvecs, SLOT,
                              VEC0_DISKANN_QUANTIZER_BINARY, DIMS);
  assert(diskann_validity_get(valid_bits, SLOT) == 0);
  assert(diskann_neighbor_id_get(neighbor_ids, SLOT) == 0);
  assert(diskann_validity_count(valid_bits, N_NEIGHBORS) == 0);

  sqlite3_free(valid_bits);
  sqlite3_free(neighbor_ids);
  sqlite3_free(neighbor_qvecs);

  printf(" All diskann_node_set_clear_neighbor tests passed.\n");
}
// Checks diskann_prune_select() against a hand-worked example with five
// candidates A..E (indices 0..4) that are already ordered by distance to p.
//
// The expected traces below follow the rule this test assumes: after a
// candidate c is picked, a remaining candidate j is pruned when
// alpha * dist(c, j) <= dist(p, j).
void test_diskann_prune_select() {
  printf("Starting %s...\n", __func__);

  enum { N = 5 };

  // Distances from p: A=1, B=2, C=3, D=4, E=5.
  float dist_to_p[N] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};

  // Symmetric N x N inter-candidate matrix. An entry depends only on how far
  // apart the two candidates sit in the ordering:
  //   gap 0 -> 0.0, gap 1 -> 1.5, gap 2 -> 3.0, gap 3 -> 4.0, gap 4 -> 5.0
  const float by_gap[N] = {0.0f, 1.5f, 3.0f, 4.0f, 5.0f};
  float pairwise[N * N];
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
      int gap = (i > j) ? (i - j) : (j - i);
      pairwise[i * N + j] = by_gap[gap];
    }
  }

  int picked[N];
  int n_picked;
  int rc;

  // alpha = 1.0, R = 3.
  // Picking A prunes everything else:
  //   B: 1.5 <= 2.0, C: 3.0 <= 3.0, D: 4.0 <= 4.0, E: 5.0 <= 5.0
  rc = diskann_prune_select(pairwise, dist_to_p, N, 1.0f, 3, picked, &n_picked);
  assert(rc == 0);
  assert(n_picked == 1);
  assert(picked[0] == 1); // A

  // alpha = 1.5, R = 3.
  // Pick A: B (2.25 > 2.0), C (4.5 > 3.0), D (6.0 > 4.0), E (7.5 > 5.0) all
  //         survive.
  // Pick B: C is pruned (2.25 <= 3.0); D (4.5 > 4.0) and E (6.0 > 5.0) stay.
  // Pick D: that is the third pick, so E is never reached.
  {
    const int want[N] = {1, 1, 0, 1, 0}; // A, B, -, D, -
    rc = diskann_prune_select(pairwise, dist_to_p, N, 1.5f, 3, picked,
                              &n_picked);
    assert(rc == 0);
    assert(n_picked == 3);
    for (int i = 0; i < N; i++) {
      assert(picked[i] == want[i]);
    }
  }

  // R larger than the candidate count and an alpha so high that nothing is
  // pruned: every candidate is selected.
  rc = diskann_prune_select(pairwise, dist_to_p, N, 100.0f, 10, picked,
                            &n_picked);
  assert(rc == 0);
  assert(n_picked == N);

  // No candidates at all.
  rc = diskann_prune_select(NULL, NULL, 0, 1.2f, 3, picked, &n_picked);
  assert(rc == 0);
  assert(n_picked == 0);

  printf(" All diskann_prune_select tests passed.\n");
}
// Checks the per-vector storage size reported for each quantizer:
// binary packs one bit per dimension, int8 uses one byte per dimension.
void test_diskann_quantized_vector_byte_size() {
  printf("Starting %s...\n", __func__);

  // Binary: dims / 8 bytes (128 -> 16, 8 -> 1, 1024 -> 128).
  const int binary_dims[] = {128, 8, 1024};
  for (int k = 0; k < 3; k++) {
    assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY,
                                              binary_dims[k]) ==
           binary_dims[k] / 8);
  }

  // INT8: dims bytes (128 -> 128, 1 -> 1, 768 -> 768).
  const int int8_dims[] = {128, 1, 768};
  for (int k = 0; k < 3; k++) {
    assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8,
                                              int8_dims[k]) == int8_dims[k]);
  }

  printf(" All diskann_quantized_vector_byte_size tests passed.\n");
}
void test_diskann_config_defaults() {
printf("Starting %s...\n", __func__);
// A freshly zero-initialized VectorColumnDefinition should have diskann.enabled == 0
struct VectorColumnDefinition col;
memset(&col, 0, sizeof(col));
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
assert(col.diskann.n_neighbors == 0);
assert(col.diskann.search_list_size == 0);
// Verify parsing a normal vector column still works and diskann is not enabled
{
const char *input = "embedding float[768]";
int rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
assert(rc == 0 /* SQLITE_OK */);
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
sqlite3_free(col.name);
}
printf(" All diskann_config_defaults tests passed.\n");
}
// ======================================================================
// Additional DiskANN unit tests
// ======================================================================
// INT8 quantization of four sample values. Per the formula this test was
// written against, the quantizer maps the fixed range [-1, 1] using
//   step = 2.0 / 255.0
//   out[i] = (i8)((src[i] + 1.0) / step - 128.0)
// All but the first output are checked against a small tolerance window.
void test_diskann_quantize_int8() {
  printf("Starting %s...\n", __func__);

  float input[4] = {-1.0f, 0.0f, 0.5f, 1.0f};
  unsigned char raw[4];

  int rc = diskann_quantize_vector(input, 4, VEC0_DISKANN_QUANTIZER_INT8, raw);
  assert(rc == 0);

  // Reinterpret the output bytes as signed 8-bit values.
  int8_t *q = (int8_t *)raw;

  // Accepted [lo, hi] window per input:
  //   -1.0 -> exactly -128
  //    0.0 -> about -0.5 before the cast, so near 0
  //    0.5 -> about 63.25 before the cast, so near 63
  //    1.0 -> close to 127 (float rounding may give 126)
  const struct {
    int lo;
    int hi;
  } window[4] = {
      {-128, -128},
      {-2, 2},
      {60, 66},
      {126, 127},
  };
  for (int k = 0; k < 4; k++) {
    assert(q[k] >= window[k].lo && q[k] <= window[k].hi);
  }

  printf(" All diskann_quantize_int8 tests passed.\n");
}
// Binary quantization across a byte boundary: 16 dimensions -> 2 bytes,
// with bit k of each byte taken from the k-th dimension of that group of 8.
void test_diskann_quantize_binary_16d() {
  printf("Starting %s...\n", __func__);

  float input[16] = {
      // dims 0..7 -> byte 0, positive at positions 0, 2, 4, 7
      1.0f, -1.0f, 0.5f, -0.5f, 0.1f, -0.1f, 0.0f, 100.0f,
      // dims 8..15 -> byte 1, positive at positions 1, 2, 3, 6
      -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f,
  };
  unsigned char packed[2];

  const unsigned char expected[2] = {
      (unsigned char)((1u << 0) | (1u << 2) | (1u << 4) | (1u << 7)), // 0x95
      (unsigned char)((1u << 1) | (1u << 2) | (1u << 3) | (1u << 6)), // 0x4E
  };

  int rc = diskann_quantize_vector(input, 16, VEC0_DISKANN_QUANTIZER_BINARY,
                                   packed);
  assert(rc == 0);
  for (int b = 0; b < 2; b++) {
    assert(packed[b] == expected[b]);
  }

  printf(" All diskann_quantize_binary_16d tests passed.\n");
}
// Eight strictly positive components must quantize to a byte with every
// bit set.
void test_diskann_quantize_binary_all_positive() {
  printf("Starting %s...\n", __func__);

  float input[8] = {1.0f, 2.0f, 0.1f, 0.001f, 100.0f, 42.0f, 0.5f, 3.14f};
  unsigned char packed[1];
  const unsigned char all_bits = 0xFF;

  int rc = diskann_quantize_vector(input, 8, VEC0_DISKANN_QUANTIZER_BINARY,
                                   packed);
  assert(rc == 0);
  assert(packed[0] == all_bits);

  printf(" All diskann_quantize_binary_all_positive tests passed.\n");
}
// Seven negative components plus an exact zero must quantize to a byte with
// no bits set: the test expects values <= 0 to map to bit 0.
void test_diskann_quantize_binary_all_negative() {
  printf("Starting %s...\n", __func__);

  float input[8] = {-1.0f, -2.0f, -0.1f, -0.001f, -100.0f, -42.0f, -0.5f, 0.0f};
  unsigned char packed[1];
  const unsigned char no_bits = 0x00;

  int rc = diskann_quantize_vector(input, 8, VEC0_DISKANN_QUANTIZER_BINARY,
                                   packed);
  assert(rc == 0);
  assert(packed[0] == no_bits);

  printf(" All diskann_quantize_binary_all_negative tests passed.\n");
}
void test_diskann_candidate_list_operations() {
printf("Starting %s...\n", __func__);
struct DiskannCandidateList list;
int rc = _test_diskann_candidate_list_init(&list, 5);
assert(rc == 0);
// Insert candidates in non-sorted order
_test_diskann_candidate_list_insert(&list, 10, 3.0f);
_test_diskann_candidate_list_insert(&list, 20, 1.0f);
_test_diskann_candidate_list_insert(&list, 30, 2.0f);
assert(_test_diskann_candidate_list_count(&list) == 3);
// Should be sorted by distance
assert(_test_diskann_candidate_list_rowid(&list, 0) == 20); // dist 1.0
assert(_test_diskann_candidate_list_rowid(&list, 1) == 30); // dist 2.0
assert(_test_diskann_candidate_list_rowid(&list, 2) == 10); // dist 3.0
assert(_test_diskann_candidate_list_distance(&list, 0) == 1.0f);
assert(_test_diskann_candidate_list_distance(&list, 1) == 2.0f);
assert(_test_diskann_candidate_list_distance(&list, 2) == 3.0f);
// Deduplication: inserting same rowid with better distance should update
_test_diskann_candidate_list_insert(&list, 10, 0.5f);
assert(_test_diskann_candidate_list_count(&list) == 3); // Same count
assert(_test_diskann_candidate_list_rowid(&list, 0) == 10); // Now first
assert(_test_diskann_candidate_list_distance(&list, 0) == 0.5f);
// Next unvisited: should be index 0
int idx = _test_diskann_candidate_list_next_unvisited(&list);
assert(idx == 0);
// Mark visited
_test_diskann_candidate_list_set_visited(&list, 0);
idx = _test_diskann_candidate_list_next_unvisited(&list);
assert(idx == 1); // Skip visited
// Fill to capacity (5) and try inserting a worse candidate
_test_diskann_candidate_list_insert(&list, 40, 4.0f);
_test_diskann_candidate_list_insert(&list, 50, 5.0f);
assert(_test_diskann_candidate_list_count(&list) == 5);
// Insert worse than worst -> should be discarded
int inserted = _test_diskann_candidate_list_insert(&list, 60, 10.0f);
assert(inserted == 0);
assert(_test_diskann_candidate_list_count(&list) == 5);
// Insert better than worst -> should replace worst
inserted = _test_diskann_candidate_list_insert(&list, 60, 3.5f);
assert(inserted == 1);
assert(_test_diskann_candidate_list_count(&list) == 5);
_test_diskann_candidate_list_free(&list);
printf(" All diskann_candidate_list_operations tests passed.\n");
}
void test_diskann_visited_set_operations() {
printf("Starting %s...\n", __func__);
struct DiskannVisitedSet set;
int rc = _test_diskann_visited_set_init(&set, 32);
assert(rc == 0);
// Empty set
assert(_test_diskann_visited_set_contains(&set, 1) == 0);
assert(_test_diskann_visited_set_contains(&set, 100) == 0);
// Insert and check
int inserted = _test_diskann_visited_set_insert(&set, 42);
assert(inserted == 1);
assert(_test_diskann_visited_set_contains(&set, 42) == 1);
assert(_test_diskann_visited_set_contains(&set, 43) == 0);
// Double insert returns 0
inserted = _test_diskann_visited_set_insert(&set, 42);
assert(inserted == 0);
// Insert several
_test_diskann_visited_set_insert(&set, 1);
_test_diskann_visited_set_insert(&set, 2);
_test_diskann_visited_set_insert(&set, 100);
_test_diskann_visited_set_insert(&set, 999);
assert(_test_diskann_visited_set_contains(&set, 1) == 1);
assert(_test_diskann_visited_set_contains(&set, 2) == 1);
assert(_test_diskann_visited_set_contains(&set, 100) == 1);
assert(_test_diskann_visited_set_contains(&set, 999) == 1);
assert(_test_diskann_visited_set_contains(&set, 3) == 0);
// Sentinel value (rowid 0) should not be insertable
assert(_test_diskann_visited_set_contains(&set, 0) == 0);
inserted = _test_diskann_visited_set_insert(&set, 0);
assert(inserted == 0);
_test_diskann_visited_set_free(&set);
printf(" All diskann_visited_set_operations tests passed.\n");
}
// With exactly one candidate, diskann_prune_select() must select it.
void test_diskann_prune_select_single_candidate() {
  printf("Starting %s...\n", __func__);

  float dist_to_p[1] = {5.0f};
  float pairwise[1] = {0.0f}; // 1x1 matrix: distance of the candidate to itself
  int picked[1];
  int n_picked;

  // alpha = 1.0, R = 3
  int rc = diskann_prune_select(pairwise, dist_to_p, 1, 1.0f, 3, picked,
                                &n_picked);
  assert(rc == 0);
  assert(n_picked == 1);
  assert(picked[0] == 1);

  printf(" All diskann_prune_select_single_candidate tests passed.\n");
}
// Degenerate input for diskann_prune_select(): four candidates, all at the
// same distance from p and all at the same distance from each other.
void test_diskann_prune_select_all_identical_distances() {
  printf("Starting %s...\n", __func__);

  enum { N = 4 };

  float dist_to_p[N] = {2.0f, 2.0f, 2.0f, 2.0f};

  // 0.0 on the diagonal, 1.0 everywhere else.
  float pairwise[N * N];
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
      pairwise[i * N + j] = (i == j) ? 0.0f : 1.0f;
    }
  }

  int picked[N];
  int n_picked;

  // alpha = 1.0, R = 4. After the first pick, 1.0 * 1.0 <= 2.0 holds for
  // every other candidate, so they are expected to be pruned. The test only
  // requires that at least one candidate is selected.
  int rc = diskann_prune_select(pairwise, dist_to_p, N, 1.0f, N, picked,
                                &n_picked);
  assert(rc == 0);
  assert(n_picked >= 1);

  printf(" All diskann_prune_select_all_identical_distances tests passed.\n");
}
// With R = 1, diskann_prune_select() must stop after one pick, and that pick
// must be the closest candidate (index 0).
void test_diskann_prune_select_max_neighbors_1() {
  printf("Starting %s...\n", __func__);

  enum { N = 3 };

  float dist_to_p[N] = {1.0f, 2.0f, 3.0f};

  // 0.0 on the diagonal, 5.0 between any two distinct candidates.
  float pairwise[N * N];
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
      pairwise[i * N + j] = (i == j) ? 0.0f : 5.0f;
    }
  }

  int picked[N];
  int n_picked;

  // alpha = 1.0, R = 1
  int rc = diskann_prune_select(pairwise, dist_to_p, N, 1.0f, 1, picked,
                                &n_picked);
  assert(rc == 0);
  assert(n_picked == 1);
  assert(picked[0] == 1);

  printf(" All diskann_prune_select_max_neighbors_1 tests passed.\n");
}
int main() {
printf("Starting unit tests...\n");
@ -1402,5 +2106,23 @@ int main() {
test_ivf_quantize_binary();
test_ivf_config_parsing();
#endif
test_vec0_parse_vector_column_diskann();
test_diskann_validity_bitmap();
test_diskann_neighbor_ids();
test_diskann_quantize_binary();
test_diskann_node_init_sizes();
test_diskann_node_set_clear_neighbor();
test_diskann_prune_select();
test_diskann_quantized_vector_byte_size();
test_diskann_config_defaults();
test_diskann_quantize_int8();
test_diskann_quantize_binary_16d();
test_diskann_quantize_binary_all_positive();
test_diskann_quantize_binary_all_negative();
test_diskann_candidate_list_operations();
test_diskann_visited_set_operations();
test_diskann_prune_select_single_candidate();
test_diskann_prune_select_all_identical_distances();
test_diskann_prune_select_max_neighbors_1();
printf("All unit tests passed.\n");
}