#include "../sqlite-vec.h"
#include "sqlite-vec-internal.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

#define countof(x) (sizeof(x) / sizeof((x)[0]))

// Calls vec0_token_next() on the full extent of the NUL-terminated string
// `input`. Deriving the end pointer with strlen() keeps it in sync with the
// literal: a hand-counted length that overshoots points past the end of the
// string, so the tokenizer would be handed bytes that are not part of the
// input (and forming a pointer more than one past an array is undefined).
static int token_next_cstr(char *input, struct Vec0Token *token) {
  return vec0_token_next(input, input + strlen(input), token);
}

// Tests vec0_token_next(), the low-level tokenizer that extracts the next
// token from a raw char range. Covers every token type (identifier, digit,
// brackets, parentheses, comma, plus, equals), whitespace skipping, EOF on
// empty/whitespace-only input, error on unrecognised characters, and boundary
// behaviour where identifiers and digits stop at the next non-matching
// character.
void test_vec0_token_next(void) {
  printf("Starting %s...\n", __func__);
  struct Vec0Token token;
  int rc;
  char *input;

  // Single-character tokens
  input = "+";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_PLUS);

  input = "[";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_LBRACKET);

  input = "]";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_RBRACKET);

  input = "=";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_EQ);

  input = "(";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_LPAREN);

  input = ")";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_RPAREN);

  input = ",";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_COMMA);

  // Identifier
  input = "hello";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
  assert(token.start == input);
  assert(token.end == input + 5);

  // Identifier with underscores and digits
  input = "col_1a";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
  assert(token.end - token.start == 6);

  // Digit sequence
  input = "1234";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_DIGIT);
  assert(token.start == input);
  assert(token.end == input + 4);

  // Leading whitespace is skipped; the token starts at the first non-blank.
  input = "  abc";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
  assert(token.start == input + 2);
  assert(token.end - token.start == 3);

  // Tab/newline/carriage-return whitespace
  input = "\t\n\r X";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
  assert(token.start == input + 4);
  assert(token.end - token.start == 1);

  // Empty input
  input = "";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_EOF);

  // Only whitespace
  input = "   ";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_EOF);

  // Unrecognized characters
  input = "@";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_ERROR);

  input = "!";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_ERROR);

  // Identifier stops at bracket
  input = "foo[";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
  assert(token.end - token.start == 3);

  // Digit stops at non-digit
  input = "42abc";
  rc = token_next_cstr(input, &token);
  assert(rc == VEC0_TOKEN_RESULT_SOME);
  assert(token.token_type == TOKEN_TYPE_DIGIT);
  assert(token.end - token.start == 2);

  printf(" All vec0_token_next tests passed.\n");
}

// Tests Vec0Scanner, the stateful wrapper around vec0_token_next() that
// tracks position and yields successive tokens.
Verifies correct tokenisation // of full sequences like "abc float[128]" and "key=value", empty input, // whitespace-heavy input, and expressions with operators ("a+b"). void test_vec0_scanner() { printf("Starting %s...\n", __func__); struct Vec0Scanner scanner; struct Vec0Token token; int rc; // Scan "abc float[128]" { const char *input = "abc float[128]"; vec0_scanner_init(&scanner, input, (int)strlen(input)); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(token.end - token.start == 3); assert(strncmp(token.start, "abc", 3) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(token.end - token.start == 5); assert(strncmp(token.start, "float", 5) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_LBRACKET); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_DIGIT); assert(strncmp(token.start, "128", 3) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_RBRACKET); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_EOF); } // Scan "key=value" { const char *input = "key=value"; vec0_scanner_init(&scanner, input, (int)strlen(input)); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(strncmp(token.start, "key", 3) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_EQ); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(strncmp(token.start, "value", 5) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == 
VEC0_TOKEN_RESULT_EOF); } // Scan empty string { const char *input = ""; vec0_scanner_init(&scanner, input, 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_EOF); } // Scan with lots of whitespace { const char *input = " a b "; vec0_scanner_init(&scanner, input, (int)strlen(input)); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(token.end - token.start == 1); assert(*token.start == 'a'); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(token.end - token.start == 1); assert(*token.start == 'b'); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_EOF); } // Scan "a+b" { const char *input = "a+b"; vec0_scanner_init(&scanner, input, (int)strlen(input)); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_PLUS); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_EOF); } // Scan "diskann(k=v, k2=v2)" { const char *input = "diskann(k=v, k2=v2)"; vec0_scanner_init(&scanner, input, (int)strlen(input)); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(strncmp(token.start, "diskann", 7) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_LPAREN); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(strncmp(token.start, "k", 1) == 0); rc = 
vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_EQ); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(strncmp(token.start, "v", 1) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_COMMA); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(strncmp(token.start, "k2", 2) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_EQ); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_IDENTIFIER); assert(strncmp(token.start, "v2", 2) == 0); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_SOME); assert(token.token_type == TOKEN_TYPE_RPAREN); rc = vec0_scanner_next(&scanner, &token); assert(rc == VEC0_TOKEN_RESULT_EOF); } printf(" All vec0_scanner tests passed.\n"); } // Tests vec0_parse_vector_column(), which parses a vec0 column definition // string like "embedding float[768] distance_metric=cosine" into a // VectorColumnDefinition struct. Covers all element types (float/f32, int8/i8, // bit), column names with underscores/digits, all distance metrics (L2, L1, // cosine), the default metric, and error cases: empty input, missing type, // unknown type, missing dimensions, unknown metric, unknown option key, and // distance_metric on bit columns. 
void test_vec0_parse_vector_column() { printf("Starting %s...\n", __func__); struct VectorColumnDefinition col; int rc; // Basic float column { const char *input = "embedding float[768]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.name_length == 9); assert(strncmp(col.name, "embedding", 9) == 0); assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32); assert(col.dimensions == 768); assert(col.distance_metric == VEC0_DISTANCE_METRIC_L2); sqlite3_free(col.name); } // f32 alias { const char *input = "v f32[3]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32); assert(col.dimensions == 3); sqlite3_free(col.name); } // int8 column { const char *input = "quantized int8[256]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_INT8); assert(col.dimensions == 256); assert(col.name_length == 9); assert(strncmp(col.name, "quantized", 9) == 0); sqlite3_free(col.name); } // i8 alias { const char *input = "q i8[64]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_INT8); assert(col.dimensions == 64); sqlite3_free(col.name); } // bit column { const char *input = "bvec bit[1024]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_BIT); assert(col.dimensions == 1024); sqlite3_free(col.name); } // Column name with underscores and digits { const char *input = "col_name_2 float[10]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.name_length == 10); assert(strncmp(col.name, "col_name_2", 10) == 0); sqlite3_free(col.name); } // distance_metric=cosine { const char *input = "emb float[128] 
distance_metric=cosine"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); assert(col.dimensions == 128); sqlite3_free(col.name); } // distance_metric=L2 (explicit) { const char *input = "emb float[128] distance_metric=L2"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.distance_metric == VEC0_DISTANCE_METRIC_L2); sqlite3_free(col.name); } // distance_metric=L1 { const char *input = "emb float[128] distance_metric=l1"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_OK); assert(col.distance_metric == VEC0_DISTANCE_METRIC_L1); sqlite3_free(col.name); } // SQLITE_EMPTY: empty string { const char *input = ""; rc = vec0_parse_vector_column(input, 0, &col); assert(rc == SQLITE_EMPTY); } // SQLITE_EMPTY: non-vector column (text primary key) { const char *input = "document_id text primary key"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_EMPTY); } // SQLITE_EMPTY: non-vector column (partition key) { const char *input = "user_id integer partition key"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_EMPTY); } // SQLITE_EMPTY: no type (single identifier) { const char *input = "emb"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_EMPTY); } // SQLITE_EMPTY: unknown type { const char *input = "emb double[128]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_EMPTY); } // SQLITE_EMPTY: unknown type (unknowntype) { const char *input = "v unknowntype[128]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_EMPTY); } // SQLITE_EMPTY: missing brackets entirely { const char *input = "emb float"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_EMPTY); } // Error: zero dimensions { 
const char *input = "v float[0]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_ERROR); } // Error: empty brackets (no dimensions) { const char *input = "v float[]"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_ERROR); } // Error: unknown distance metric { const char *input = "emb float[128] distance_metric=hamming"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_ERROR); } // Error: unknown distance metric (foo) { const char *input = "v float[128] distance_metric=foo"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_ERROR); } // Error: unknown option key { const char *input = "emb float[128] foobar=baz"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_ERROR); } // Error: distance_metric on bit type { const char *input = "emb bit[64] distance_metric=cosine"; rc = vec0_parse_vector_column(input, (int)strlen(input), &col); assert(rc == SQLITE_ERROR); } printf(" All vec0_parse_vector_column tests passed.\n"); } // Tests vec0_parse_partition_key_definition(), which parses a vec0 partition // key column definition like "user_id integer partition key". Verifies correct // parsing of integer and text partition keys, column name extraction, and // rejection of invalid inputs: empty strings, non-partition-key definitions // ("primary key"), and misspelled keywords. 
void test_vec0_parse_partition_key_definition(void) {
  printf("Starting %s...\n", __func__);
  typedef struct {
    // String literals are read-only, so the pointers are const-qualified.
    const char *test;
    int expected_rc;
    const char *expected_column_name;
    int expected_column_type;
  } TestCase;

  TestCase suite[] = {
      {"user_id integer partition key", SQLITE_OK, "user_id", SQLITE_INTEGER},
      {"USER_id int partition key", SQLITE_OK, "USER_id", SQLITE_INTEGER},
      {"category text partition key", SQLITE_OK, "category", SQLITE_TEXT},
      {"", SQLITE_EMPTY, "", 0},
      {"document_id text primary key", SQLITE_EMPTY, "", 0},
      {"document_id text partition keyy", SQLITE_EMPTY, "", 0},
  };

  // countof() yields a size_t, so the index is size_t as well to avoid a
  // signed/unsigned comparison.
  for (size_t i = 0; i < countof(suite); i++) {
    char *out_column_name = NULL;
    int out_column_name_length = 0;
    int out_column_type = 0;
    int rc = vec0_parse_partition_key_definition(
        suite[i].test, (int)strlen(suite[i].test), &out_column_name,
        &out_column_name_length, &out_column_type);
    assert(rc == suite[i].expected_rc);
    if (rc == SQLITE_OK) {
      assert(out_column_name != NULL);
      assert((size_t)out_column_name_length ==
             strlen(suite[i].expected_column_name));
      assert(strncmp(out_column_name, suite[i].expected_column_name,
                     out_column_name_length) == 0);
      assert(out_column_type == suite[i].expected_column_type);
    }
    printf(" Passed: \"%s\"\n", suite[i].test);
  }
  printf(" All vec0_parse_partition_key_definition tests passed.\n");
}

// Tests _test_distance_l2_sqr_float(). Despite the "sqr" in the name, the
// expected values below are the square root of the summed squared
// differences, i.e. the Euclidean distance.
void test_distance_l2_sqr_float(void) {
  printf("Starting %s...\n", __func__);
  float d;

  // Identical vectors: distance = 0
  {
    float a[] = {1.0f, 2.0f, 3.0f};
    float b[] = {1.0f, 2.0f, 3.0f};
    d = _test_distance_l2_sqr_float(a, b, 3);
    assert(d == 0.0f);
  }

  // Orthogonal unit vectors: sqrt(1+1) = sqrt(2)
  {
    float a[] = {1.0f, 0.0f, 0.0f};
    float b[] = {0.0f, 1.0f, 0.0f};
    d = _test_distance_l2_sqr_float(a, b, 3);
    assert(fabsf(d - sqrtf(2.0f)) < 1e-6f);
  }

  // Known computation: [1,2,3] vs [4,5,6] = sqrt(9+9+9) = sqrt(27)
  {
    float a[] = {1.0f, 2.0f, 3.0f};
    float b[] = {4.0f, 5.0f, 6.0f};
    d = _test_distance_l2_sqr_float(a, b, 3);
    assert(fabsf(d - sqrtf(27.0f)) < 1e-5f);
  }

  // Single dimension: sqrt(16) = 4.0, exactly representable
  {
    float a[] = {3.0f};
    float b[] = {7.0f};
    d = _test_distance_l2_sqr_float(a, b, 1);
    assert(d == 4.0f);
  }

  printf(" All distance_l2_sqr_float tests passed.\n");
}

// Tests _test_distance_cosine_float(): 0 for the same direction, 1 for
// orthogonal vectors, 2 for opposite directions.
void test_distance_cosine_float(void) {
  printf("Starting %s...\n", __func__);
  float d;

  // Identical direction: distance = 0.0
  {
    float a[] = {1.0f, 0.0f};
    float b[] = {2.0f, 0.0f};
    d = _test_distance_cosine_float(a, b, 2);
    assert(fabsf(d) < 1e-6f);
  }

  // Orthogonal: distance = 1.0
  {
    float a[] = {1.0f, 0.0f};
    float b[] = {0.0f, 1.0f};
    d = _test_distance_cosine_float(a, b, 2);
    assert(fabsf(d - 1.0f) < 1e-6f);
  }

  // Opposite direction: distance = 2.0
  {
    float a[] = {1.0f, 0.0f};
    float b[] = {-1.0f, 0.0f};
    d = _test_distance_cosine_float(a, b, 2);
    assert(fabsf(d - 2.0f) < 1e-6f);
  }

  printf(" All distance_cosine_float tests passed.\n");
}

// Tests _test_distance_hamming() on packed bitmaps; the last argument is the
// number of bits compared (8 per byte in every case below).
void test_distance_hamming(void) {
  printf("Starting %s...\n", __func__);
  float d;

  // Identical bitmaps: distance = 0
  {
    unsigned char a[] = {0xFF};
    unsigned char b[] = {0xFF};
    d = _test_distance_hamming(a, b, 8);
    assert(d == 0.0f);
  }

  // All different: distance = 8
  {
    unsigned char a[] = {0xFF};
    unsigned char b[] = {0x00};
    d = _test_distance_hamming(a, b, 8);
    assert(d == 8.0f);
  }

  // Half different: 0xFF vs 0x0F = 4 bits differ
  {
    unsigned char a[] = {0xFF};
    unsigned char b[] = {0x0F};
    d = _test_distance_hamming(a, b, 8);
    assert(d == 4.0f);
  }

  // Multi-byte: [0xFF, 0x00] vs [0x00, 0xFF] = 16 bits differ
  {
    unsigned char a[] = {0xFF, 0x00};
    unsigned char b[] = {0x00, 0xFF};
    d = _test_distance_hamming(a, b, 16);
    assert(d == 16.0f);
  }

  printf(" All distance_hamming tests passed.\n");
}

// Entry point: reports which SIMD paths were compiled in, then runs every
// unit test. Any failing assert aborts the process.
int main(void) {
  printf("Starting unit tests...\n");
#ifdef SQLITE_VEC_ENABLE_AVX
  printf("SQLITE_VEC_ENABLE_AVX=1\n");
#endif
#ifdef SQLITE_VEC_ENABLE_NEON
  printf("SQLITE_VEC_ENABLE_NEON=1\n");
#endif
#if !defined(SQLITE_VEC_ENABLE_AVX) && !defined(SQLITE_VEC_ENABLE_NEON)
  printf("SIMD: none\n");
#endif
  test_vec0_token_next();
  test_vec0_scanner();
  test_vec0_parse_vector_column();
  test_vec0_parse_partition_key_definition();
  test_distance_l2_sqr_float();
  test_distance_cosine_float();
  test_distance_hamming();
  printf("All unit tests passed.\n");
  return 0;
}