From 0659d8848da2c127d9ba1ebb678f0ba37ee66429 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Mon, 2 Mar 2026 17:46:11 -0800 Subject: [PATCH] Update test-unit.c and unittest.rs functions to enforce pre-existing behavior - Expand sqlite-vec-internal.h with scanner/tokenizer types, vector column definition types, and parser function declarations - Fix min_idx declaration to match actual C signature (add candidates, bTaken, k_used params) - Compile test-unit with -DSQLITE_CORE and link vendor/sqlite3.c so sqlite3 API functions (sqlite3_strnicmp, sqlite3_mprintf, etc.) resolve - Add unit tests for vec0_token_next, Vec0Scanner, and vec0_parse_vector_column - Fix Rust build.rs to define SQLITE_CORE and compile vendor/sqlite3.c - Fix Rust min_idx FFI signature and wrapper to match actual C function Co-Authored-By: Claude Opus 4.6 --- Makefile | 2 +- tests/build.rs | 11 +- tests/sqlite-vec-internal.h | 76 ++++++- tests/test-unit.c | 395 +++++++++++++++++++++++++++++++++++- tests/unittest.rs | 41 +++- 5 files changed, 503 insertions(+), 22 deletions(-) diff --git a/Makefile b/Makefile index ee18c9c..21a814d 100644 --- a/Makefile +++ b/Makefile @@ -190,7 +190,7 @@ test-loadable-watch: watchexec --exts c,py,Makefile --clear -- make test-loadable test-unit: - $(CC) tests/test-unit.c sqlite-vec.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit + $(CC) -DSQLITE_CORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit site-dev: npm --prefix site run dev diff --git a/tests/build.rs b/tests/build.rs index 842cf2e..41d85f4 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -1,13 +1,12 @@ -use std::env; -use std::path::{Path, PathBuf}; -use std::process::Command; - fn main() { cc::Build::new() .file("../sqlite-vec.c") - .include(".") + .file("../vendor/sqlite3.c") + .define("SQLITE_CORE", None) + .include("../vendor") + .include("..") .static_flag(true) .compile("sqlite-vec-internal"); - println!("cargo:rerun-if-changed=usleep.c"); println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-changed=../sqlite-vec.c"); } diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h index d81ab7a..3a1f213 100644 --- a/tests/sqlite-vec-internal.h +++ b/tests/sqlite-vec-internal.h @@ -1,12 +1,78 @@ +#ifndef SQLITE_VEC_INTERNAL_H +#define SQLITE_VEC_INTERNAL_H + #include +#include int min_idx( - // list of distances, size n const float *distances, - // number of entries in distances int32_t n, - // output array of size k, the indicies of the lowest k values in distances + uint8_t *candidates, int32_t *out, - // output number of elements - int32_t k + int32_t k, + uint8_t *bTaken, + int32_t *k_used ); + +// Scanner / tokenizer types and functions + +enum Vec0TokenType { + TOKEN_TYPE_IDENTIFIER = 0, + TOKEN_TYPE_DIGIT = 1, + TOKEN_TYPE_LBRACKET = 2, + TOKEN_TYPE_RBRACKET = 3, + TOKEN_TYPE_PLUS = 4, + TOKEN_TYPE_EQ = 5, +}; + +#define VEC0_TOKEN_RESULT_EOF 1 +#define VEC0_TOKEN_RESULT_SOME 2 +#define VEC0_TOKEN_RESULT_ERROR 3 + +struct Vec0Token { + enum Vec0TokenType token_type; + char *start; + char *end; +}; + +struct Vec0Scanner { + char *start; + char *end; + char *ptr; +}; + +void vec0_scanner_init(struct Vec0Scanner *scanner, const char *source, int source_length); +int vec0_scanner_next(struct Vec0Scanner *scanner, struct Vec0Token *out); +int vec0_token_next(char *start, char *end, struct Vec0Token *out); + +// Vector column definition types and parser + +enum VectorElementType { + SQLITE_VEC_ELEMENT_TYPE_FLOAT32 = 223 + 0, + SQLITE_VEC_ELEMENT_TYPE_BIT = 223 + 1, + SQLITE_VEC_ELEMENT_TYPE_INT8 = 223 + 2, +}; + +enum Vec0DistanceMetrics { + VEC0_DISTANCE_METRIC_L2 = 1, + VEC0_DISTANCE_METRIC_COSINE = 2, + VEC0_DISTANCE_METRIC_L1 = 3, +}; + +struct VectorColumnDefinition { + char *name; + int name_length; + size_t dimensions; + enum VectorElementType element_type; + enum Vec0DistanceMetrics distance_metric; +}; + +int vec0_parse_vector_column(const char *source, int source_length, + struct VectorColumnDefinition *outColumn); + +int vec0_parse_partition_key_definition(const char *source, int source_length, + char **out_column_name, + int *out_column_name_length, + int *out_column_type); + +#endif /* SQLITE_VEC_INTERNAL_H */ diff --git a/tests/test-unit.c b/tests/test-unit.c index d9a1211..9457de8 100644 --- a/tests/test-unit.c +++ b/tests/test-unit.c @@ -1,10 +1,398 @@ #include "../sqlite-vec.h" +#include "sqlite-vec-internal.h" #include #include #include #define countof(x) (sizeof(x) / sizeof((x)[0])) +// Tests vec0_token_next(), the low-level tokenizer that extracts the next +// token from a raw char range. Covers every token type (identifier, digit, +// brackets, plus, equals), whitespace skipping, EOF on empty/whitespace-only +// input, error on unrecognised characters, and boundary behaviour where +// identifiers and digits stop at the next non-matching character. +void test_vec0_token_next() { + printf("Starting %s...\n", __func__); + struct Vec0Token token; + int rc; + char *input; + + // Single-character tokens + input = "+"; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_PLUS); + + input = "["; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_LBRACKET); + + input = "]"; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_RBRACKET); + + input = "="; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_EQ); + + // Identifier + input = "hello"; + rc = vec0_token_next(input, input + 5, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.start == input); + assert(token.end == input + 5); + + // Identifier with underscores and digits + input = "col_1a"; + rc = vec0_token_next(input, input + 6, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.end - token.start == 6); + + // Digit sequence + input = "1234"; + rc = vec0_token_next(input, input + 4, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_DIGIT); + assert(token.start == input); + assert(token.end == input + 4); + + // Leading whitespace is skipped + input = " abc"; + rc = vec0_token_next(input, input + 5, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.end - token.start == 3); + + // Tab/newline whitespace + input = "\t\n\r X"; + rc = vec0_token_next(input, input + 5, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + + // Empty input + input = ""; + rc = vec0_token_next(input, input, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + + // Only whitespace + input = " "; + rc = vec0_token_next(input, input + 3, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + + // Unrecognized character + input = "@"; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_ERROR); + + input = "!"; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_ERROR); + + // Identifier stops at bracket + input = "foo["; + rc = vec0_token_next(input, input + 4, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.end - token.start == 3); + + // Digit stops at non-digit + input = "42abc"; + rc = vec0_token_next(input, input + 5, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_DIGIT); + assert(token.end - token.start == 2); + + printf(" All vec0_token_next tests passed.\n"); +} + +// Tests Vec0Scanner, the stateful wrapper around vec0_token_next() that +// tracks position and yields successive tokens. Verifies correct tokenisation +// of full sequences like "abc float[128]" and "key=value", empty input, +// whitespace-heavy input, and expressions with operators ("a+b"). +void test_vec0_scanner() { + printf("Starting %s...\n", __func__); + struct Vec0Scanner scanner; + struct Vec0Token token; + int rc; + + // Scan "abc float[128]" + { + const char *input = "abc float[128]"; + vec0_scanner_init(&scanner, input, (int)strlen(input)); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.end - token.start == 3); + assert(strncmp(token.start, "abc", 3) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.end - token.start == 5); + assert(strncmp(token.start, "float", 5) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_LBRACKET); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_DIGIT); + assert(strncmp(token.start, "128", 3) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_RBRACKET); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + } + + // Scan "key=value" + { + const char *input = "key=value"; + vec0_scanner_init(&scanner, input, (int)strlen(input)); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(strncmp(token.start, "key", 3) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_EQ); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(strncmp(token.start, "value", 5) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + } + + // Scan empty string + { + const char *input = ""; + vec0_scanner_init(&scanner, input, 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + } + + // Scan with lots of whitespace + { + const char *input = " a b "; + vec0_scanner_init(&scanner, input, (int)strlen(input)); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.end - token.start == 1); + assert(*token.start == 'a'); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(token.end - token.start == 1); + assert(*token.start == 'b'); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + } + + // Scan "a+b" + { + const char *input = "a+b"; + vec0_scanner_init(&scanner, input, (int)strlen(input)); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_PLUS); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + } + + printf(" All vec0_scanner tests passed.\n"); +} + +// Tests vec0_parse_vector_column(), which parses a vec0 column definition +// string like "embedding float[768] distance_metric=cosine" into a +// VectorColumnDefinition struct. Covers all element types (float/f32, int8/i8, +// bit), column names with underscores/digits, all distance metrics (L2, L1, +// cosine), the default metric, and error cases: empty input, missing type, +// unknown type, missing dimensions, unknown metric, unknown option key, and +// distance_metric on bit columns. +void test_vec0_parse_vector_column() { + printf("Starting %s...\n", __func__); + struct VectorColumnDefinition col; + int rc; + + // Basic float column + { + const char *input = "embedding float[768]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.name_length == 9); + assert(strncmp(col.name, "embedding", 9) == 0); + assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + assert(col.dimensions == 768); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_L2); + sqlite3_free(col.name); + } + + // f32 alias + { + const char *input = "v f32[3]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + assert(col.dimensions == 3); + sqlite3_free(col.name); + } + + // int8 column + { + const char *input = "quantized int8[256]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_INT8); + assert(col.dimensions == 256); + assert(col.name_length == 9); + assert(strncmp(col.name, "quantized", 9) == 0); + sqlite3_free(col.name); + } + + // i8 alias + { + const char *input = "q i8[64]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_INT8); + assert(col.dimensions == 64); + sqlite3_free(col.name); + } + + // bit column + { + const char *input = "bvec bit[1024]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_BIT); + assert(col.dimensions == 1024); + sqlite3_free(col.name); + } + + // Column name with underscores and digits + { + const char *input = "col_name_2 float[10]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.name_length == 10); + assert(strncmp(col.name, "col_name_2", 10) == 0); + sqlite3_free(col.name); + } + + // distance_metric=cosine + { + const char *input = "emb float[128] distance_metric=cosine"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); + assert(col.dimensions == 128); + sqlite3_free(col.name); + } + + // distance_metric=L2 (explicit) + { + const char *input = "emb float[128] distance_metric=L2"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_L2); + sqlite3_free(col.name); + } + + // distance_metric=L1 + { + const char *input = "emb float[128] distance_metric=l1"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_L1); + sqlite3_free(col.name); + } + + // Error: empty string + { + const char *input = ""; + rc = vec0_parse_vector_column(input, 0, &col); + assert(rc != SQLITE_OK); + } + + // Error: no type + { + const char *input = "emb"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc != SQLITE_OK); + } + + // Error: unknown type + { + const char *input = "emb double[128]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc != SQLITE_OK); + } + + // Error: missing dimensions + { + const char *input = "emb float"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc != SQLITE_OK); + } + + // Error: unknown distance metric + { + const char *input = "emb float[128] distance_metric=hamming"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc != SQLITE_OK); + } + + // Error: unknown option key + { + const char *input = "emb float[128] foobar=baz"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc != SQLITE_OK); + } + + // Error: distance_metric on bit type + { + const char *input = "emb bit[64] distance_metric=cosine"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc != SQLITE_OK); + } + + printf(" All vec0_parse_vector_column tests passed.\n"); +} + +// Tests vec0_parse_partition_key_definition(), which parses a vec0 partition +// key column definition like "user_id integer partition key". Verifies correct +// parsing of integer and text partition keys, column name extraction, and +// rejection of invalid inputs: empty strings, non-partition-key definitions +// ("primary key"), and misspelled keywords. void test_vec0_parse_partition_key_definition() { printf("Starting %s...\n", __func__); typedef struct { @@ -35,7 +423,6 @@ void test_vec0_parse_partition_key_definition() { &out_column_name_length, &out_column_type ); - printf("2\n"); assert(rc == suite[i].expected_rc); if(rc == SQLITE_OK) { @@ -44,11 +431,15 @@ void test_vec0_parse_partition_key_definition() { assert(out_column_type == suite[i].expected_column_type); } - printf("✅ %s\n", suite[i].test); + printf(" Passed: \"%s\"\n", suite[i].test); } } int main() { printf("Starting unit tests...\n"); + test_vec0_token_next(); + test_vec0_scanner(); + test_vec0_parse_vector_column(); test_vec0_parse_partition_key_definition(); + printf("All unit tests passed.\n"); } diff --git a/tests/unittest.rs b/tests/unittest.rs index 4506b02..7d09856 100644 --- a/tests/unittest.rs +++ b/tests/unittest.rs @@ -1,19 +1,30 @@ fn main() { println!("Hello, world!"); - println!("{:?}", _min_idx(vec![3.0, 2.0, 1.0], 2)); + println!("{:?}", _min_idx(vec![3.0, 2.0, 1.0, f32::MAX, f32::MAX, f32::MAX, f32::MAX, f32::MAX], 2)); } fn _min_idx(distances: Vec, k: i32) -> Vec { + let n = distances.len(); + assert!(n % 8 == 0, "distances.len() must be a multiple of 8"); + let mut out: Vec = vec![0; k as usize]; + let bitmap_bytes = n / 8; + let mut candidates: Vec = vec![0xFF; bitmap_bytes]; + let mut b_taken: Vec = vec![0; bitmap_bytes]; + let mut k_used: i32 = 0; unsafe { min_idx( - distances.as_ptr().cast(), - distances.len() as i32, + distances.as_ptr(), + n as i32, + candidates.as_mut_ptr(), out.as_mut_ptr(), k, + b_taken.as_mut_ptr(), + &mut k_used, ); } + out.truncate(k_used as usize); out } @@ -51,7 +62,15 @@ fn _merge_sorted_lists( #[link(name = "sqlite-vec-internal")] extern "C" { - fn min_idx(distances: *const f32, n: i32, out: *mut i32, k: i32) -> i32; + fn min_idx( + distances: *const f32, + n: i32, + candidates: *mut u8, + out: *mut i32, + k: i32, + b_taken: *mut u8, + k_used: *mut i32, + ) -> i32; fn merge_sorted_lists( a: *const f32, @@ -74,11 +93,17 @@ mod tests { #[test] fn test_basic() { - assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 3), vec![0, 1, 2]); - assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 3), vec![2, 1, 0]); + let pad = |v: &[f32]| -> Vec { + let mut r = v.to_vec(); + r.resize(8, f32::MAX); + r + }; - assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 2), vec![0, 1]); - assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 2), vec![2, 1]); + assert_eq!(_min_idx(pad(&[1.0, 2.0, 3.0]), 3), vec![0, 1, 2]); + assert_eq!(_min_idx(pad(&[3.0, 2.0, 1.0]), 3), vec![2, 1, 0]); + + assert_eq!(_min_idx(pad(&[1.0, 2.0, 3.0]), 2), vec![0, 1]); + assert_eq!(_min_idx(pad(&[3.0, 2.0, 1.0]), 2), vec![2, 1]); } #[test]