mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
Update test-unit.c and unittest.rs functions to enforce pre-existing behavior
- Expand sqlite-vec-internal.h with scanner/tokenizer types, vector column definition types, and parser function declarations - Fix min_idx declaration to match actual C signature (add candidates, bTaken, k_used params) - Compile test-unit with -DSQLITE_CORE and link vendor/sqlite3.c so sqlite3 API functions (sqlite3_strnicmp, sqlite3_mprintf, etc.) resolve - Add unit tests for vec0_token_next, Vec0Scanner, and vec0_parse_vector_column - Fix Rust build.rs to define SQLITE_CORE and compile vendor/sqlite3.c - Fix Rust min_idx FFI signature and wrapper to match actual C function Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
0eb855ca67
commit
0659d8848d
5 changed files with 503 additions and 22 deletions
2
Makefile
2
Makefile
|
|
@ -190,7 +190,7 @@ test-loadable-watch:
|
||||||
watchexec --exts c,py,Makefile --clear -- make test-loadable
|
watchexec --exts c,py,Makefile --clear -- make test-loadable
|
||||||
|
|
||||||
test-unit:
|
test-unit:
|
||||||
$(CC) tests/test-unit.c sqlite-vec.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
|
$(CC) -DSQLITE_CORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
|
||||||
|
|
||||||
site-dev:
|
site-dev:
|
||||||
npm --prefix site run dev
|
npm --prefix site run dev
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
use std::env;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use std::process::Command;
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
cc::Build::new()
|
cc::Build::new()
|
||||||
.file("../sqlite-vec.c")
|
.file("../sqlite-vec.c")
|
||||||
.include(".")
|
.file("../vendor/sqlite3.c")
|
||||||
|
.define("SQLITE_CORE", None)
|
||||||
|
.include("../vendor")
|
||||||
|
.include("..")
|
||||||
.static_flag(true)
|
.static_flag(true)
|
||||||
.compile("sqlite-vec-internal");
|
.compile("sqlite-vec-internal");
|
||||||
println!("cargo:rerun-if-changed=usleep.c");
|
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
|
println!("cargo:rerun-if-changed=../sqlite-vec.c");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,78 @@
|
||||||
|
#ifndef SQLITE_VEC_INTERNAL_H
|
||||||
|
#define SQLITE_VEC_INTERNAL_H
|
||||||
|
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
int min_idx(
|
int min_idx(
|
||||||
// list of distances, size n
|
|
||||||
const float *distances,
|
const float *distances,
|
||||||
// number of entries in distances
|
|
||||||
int32_t n,
|
int32_t n,
|
||||||
// output array of size k, the indicies of the lowest k values in distances
|
uint8_t *candidates,
|
||||||
int32_t *out,
|
int32_t *out,
|
||||||
// output number of elements
|
int32_t k,
|
||||||
int32_t k
|
uint8_t *bTaken,
|
||||||
|
int32_t *k_used
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Scanner / tokenizer types and functions
|
||||||
|
|
||||||
|
enum Vec0TokenType {
|
||||||
|
TOKEN_TYPE_IDENTIFIER = 0,
|
||||||
|
TOKEN_TYPE_DIGIT = 1,
|
||||||
|
TOKEN_TYPE_LBRACKET = 2,
|
||||||
|
TOKEN_TYPE_RBRACKET = 3,
|
||||||
|
TOKEN_TYPE_PLUS = 4,
|
||||||
|
TOKEN_TYPE_EQ = 5,
|
||||||
|
};
|
||||||
|
|
||||||
|
#define VEC0_TOKEN_RESULT_EOF 1
|
||||||
|
#define VEC0_TOKEN_RESULT_SOME 2
|
||||||
|
#define VEC0_TOKEN_RESULT_ERROR 3
|
||||||
|
|
||||||
|
struct Vec0Token {
|
||||||
|
enum Vec0TokenType token_type;
|
||||||
|
char *start;
|
||||||
|
char *end;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Vec0Scanner {
|
||||||
|
char *start;
|
||||||
|
char *end;
|
||||||
|
char *ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
void vec0_scanner_init(struct Vec0Scanner *scanner, const char *source, int source_length);
|
||||||
|
int vec0_scanner_next(struct Vec0Scanner *scanner, struct Vec0Token *out);
|
||||||
|
int vec0_token_next(char *start, char *end, struct Vec0Token *out);
|
||||||
|
|
||||||
|
// Vector column definition types and parser
|
||||||
|
|
||||||
|
enum VectorElementType {
|
||||||
|
SQLITE_VEC_ELEMENT_TYPE_FLOAT32 = 223 + 0,
|
||||||
|
SQLITE_VEC_ELEMENT_TYPE_BIT = 223 + 1,
|
||||||
|
SQLITE_VEC_ELEMENT_TYPE_INT8 = 223 + 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum Vec0DistanceMetrics {
|
||||||
|
VEC0_DISTANCE_METRIC_L2 = 1,
|
||||||
|
VEC0_DISTANCE_METRIC_COSINE = 2,
|
||||||
|
VEC0_DISTANCE_METRIC_L1 = 3,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VectorColumnDefinition {
|
||||||
|
char *name;
|
||||||
|
int name_length;
|
||||||
|
size_t dimensions;
|
||||||
|
enum VectorElementType element_type;
|
||||||
|
enum Vec0DistanceMetrics distance_metric;
|
||||||
|
};
|
||||||
|
|
||||||
|
int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
|
struct VectorColumnDefinition *outColumn);
|
||||||
|
|
||||||
|
int vec0_parse_partition_key_definition(const char *source, int source_length,
|
||||||
|
char **out_column_name,
|
||||||
|
int *out_column_name_length,
|
||||||
|
int *out_column_type);
|
||||||
|
|
||||||
|
#endif /* SQLITE_VEC_INTERNAL_H */
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,398 @@
|
||||||
#include "../sqlite-vec.h"
|
#include "../sqlite-vec.h"
|
||||||
|
#include "sqlite-vec-internal.h"
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#define countof(x) (sizeof(x) / sizeof((x)[0]))
|
#define countof(x) (sizeof(x) / sizeof((x)[0]))
|
||||||
|
|
||||||
|
// Tests vec0_token_next(), the low-level tokenizer that extracts the next
|
||||||
|
// token from a raw char range. Covers every token type (identifier, digit,
|
||||||
|
// brackets, plus, equals), whitespace skipping, EOF on empty/whitespace-only
|
||||||
|
// input, error on unrecognised characters, and boundary behaviour where
|
||||||
|
// identifiers and digits stop at the next non-matching character.
|
||||||
|
void test_vec0_token_next() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
struct Vec0Token token;
|
||||||
|
int rc;
|
||||||
|
char *input;
|
||||||
|
|
||||||
|
// Single-character tokens
|
||||||
|
input = "+";
|
||||||
|
rc = vec0_token_next(input, input + 1, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_PLUS);
|
||||||
|
|
||||||
|
input = "[";
|
||||||
|
rc = vec0_token_next(input, input + 1, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_LBRACKET);
|
||||||
|
|
||||||
|
input = "]";
|
||||||
|
rc = vec0_token_next(input, input + 1, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_RBRACKET);
|
||||||
|
|
||||||
|
input = "=";
|
||||||
|
rc = vec0_token_next(input, input + 1, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_EQ);
|
||||||
|
|
||||||
|
// Identifier
|
||||||
|
input = "hello";
|
||||||
|
rc = vec0_token_next(input, input + 5, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.start == input);
|
||||||
|
assert(token.end == input + 5);
|
||||||
|
|
||||||
|
// Identifier with underscores and digits
|
||||||
|
input = "col_1a";
|
||||||
|
rc = vec0_token_next(input, input + 6, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.end - token.start == 6);
|
||||||
|
|
||||||
|
// Digit sequence
|
||||||
|
input = "1234";
|
||||||
|
rc = vec0_token_next(input, input + 4, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_DIGIT);
|
||||||
|
assert(token.start == input);
|
||||||
|
assert(token.end == input + 4);
|
||||||
|
|
||||||
|
// Leading whitespace is skipped
|
||||||
|
input = " abc";
|
||||||
|
rc = vec0_token_next(input, input + 5, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.end - token.start == 3);
|
||||||
|
|
||||||
|
// Tab/newline whitespace
|
||||||
|
input = "\t\n\r X";
|
||||||
|
rc = vec0_token_next(input, input + 5, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
|
||||||
|
// Empty input
|
||||||
|
input = "";
|
||||||
|
rc = vec0_token_next(input, input, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_EOF);
|
||||||
|
|
||||||
|
// Only whitespace
|
||||||
|
input = " ";
|
||||||
|
rc = vec0_token_next(input, input + 3, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_EOF);
|
||||||
|
|
||||||
|
// Unrecognized character
|
||||||
|
input = "@";
|
||||||
|
rc = vec0_token_next(input, input + 1, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_ERROR);
|
||||||
|
|
||||||
|
input = "!";
|
||||||
|
rc = vec0_token_next(input, input + 1, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_ERROR);
|
||||||
|
|
||||||
|
// Identifier stops at bracket
|
||||||
|
input = "foo[";
|
||||||
|
rc = vec0_token_next(input, input + 4, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.end - token.start == 3);
|
||||||
|
|
||||||
|
// Digit stops at non-digit
|
||||||
|
input = "42abc";
|
||||||
|
rc = vec0_token_next(input, input + 5, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_DIGIT);
|
||||||
|
assert(token.end - token.start == 2);
|
||||||
|
|
||||||
|
printf(" All vec0_token_next tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests Vec0Scanner, the stateful wrapper around vec0_token_next() that
|
||||||
|
// tracks position and yields successive tokens. Verifies correct tokenisation
|
||||||
|
// of full sequences like "abc float[128]" and "key=value", empty input,
|
||||||
|
// whitespace-heavy input, and expressions with operators ("a+b").
|
||||||
|
void test_vec0_scanner() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
struct Vec0Scanner scanner;
|
||||||
|
struct Vec0Token token;
|
||||||
|
int rc;
|
||||||
|
|
||||||
|
// Scan "abc float[128]"
|
||||||
|
{
|
||||||
|
const char *input = "abc float[128]";
|
||||||
|
vec0_scanner_init(&scanner, input, (int)strlen(input));
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.end - token.start == 3);
|
||||||
|
assert(strncmp(token.start, "abc", 3) == 0);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.end - token.start == 5);
|
||||||
|
assert(strncmp(token.start, "float", 5) == 0);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_LBRACKET);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_DIGIT);
|
||||||
|
assert(strncmp(token.start, "128", 3) == 0);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_RBRACKET);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_EOF);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan "key=value"
|
||||||
|
{
|
||||||
|
const char *input = "key=value";
|
||||||
|
vec0_scanner_init(&scanner, input, (int)strlen(input));
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(strncmp(token.start, "key", 3) == 0);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_EQ);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(strncmp(token.start, "value", 5) == 0);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_EOF);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan empty string
|
||||||
|
{
|
||||||
|
const char *input = "";
|
||||||
|
vec0_scanner_init(&scanner, input, 0);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_EOF);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan with lots of whitespace
|
||||||
|
{
|
||||||
|
const char *input = " a b ";
|
||||||
|
vec0_scanner_init(&scanner, input, (int)strlen(input));
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.end - token.start == 1);
|
||||||
|
assert(*token.start == 'a');
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
assert(token.end - token.start == 1);
|
||||||
|
assert(*token.start == 'b');
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_EOF);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan "a+b"
|
||||||
|
{
|
||||||
|
const char *input = "a+b";
|
||||||
|
vec0_scanner_init(&scanner, input, (int)strlen(input));
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_PLUS);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_SOME);
|
||||||
|
assert(token.token_type == TOKEN_TYPE_IDENTIFIER);
|
||||||
|
|
||||||
|
rc = vec0_scanner_next(&scanner, &token);
|
||||||
|
assert(rc == VEC0_TOKEN_RESULT_EOF);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf(" All vec0_scanner tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests vec0_parse_vector_column(), which parses a vec0 column definition
|
||||||
|
// string like "embedding float[768] distance_metric=cosine" into a
|
||||||
|
// VectorColumnDefinition struct. Covers all element types (float/f32, int8/i8,
|
||||||
|
// bit), column names with underscores/digits, all distance metrics (L2, L1,
|
||||||
|
// cosine), the default metric, and error cases: empty input, missing type,
|
||||||
|
// unknown type, missing dimensions, unknown metric, unknown option key, and
|
||||||
|
// distance_metric on bit columns.
|
||||||
|
void test_vec0_parse_vector_column() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
struct VectorColumnDefinition col;
|
||||||
|
int rc;
|
||||||
|
|
||||||
|
// Basic float column
|
||||||
|
{
|
||||||
|
const char *input = "embedding float[768]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.name_length == 9);
|
||||||
|
assert(strncmp(col.name, "embedding", 9) == 0);
|
||||||
|
assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
|
||||||
|
assert(col.dimensions == 768);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_L2);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// f32 alias
|
||||||
|
{
|
||||||
|
const char *input = "v f32[3]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
|
||||||
|
assert(col.dimensions == 3);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// int8 column
|
||||||
|
{
|
||||||
|
const char *input = "quantized int8[256]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_INT8);
|
||||||
|
assert(col.dimensions == 256);
|
||||||
|
assert(col.name_length == 9);
|
||||||
|
assert(strncmp(col.name, "quantized", 9) == 0);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// i8 alias
|
||||||
|
{
|
||||||
|
const char *input = "q i8[64]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_INT8);
|
||||||
|
assert(col.dimensions == 64);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// bit column
|
||||||
|
{
|
||||||
|
const char *input = "bvec bit[1024]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.element_type == SQLITE_VEC_ELEMENT_TYPE_BIT);
|
||||||
|
assert(col.dimensions == 1024);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column name with underscores and digits
|
||||||
|
{
|
||||||
|
const char *input = "col_name_2 float[10]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.name_length == 10);
|
||||||
|
assert(strncmp(col.name, "col_name_2", 10) == 0);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// distance_metric=cosine
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] distance_metric=cosine";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
|
||||||
|
assert(col.dimensions == 128);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// distance_metric=L2 (explicit)
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] distance_metric=L2";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_L2);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// distance_metric=L1
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] distance_metric=l1";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_L1);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: empty string
|
||||||
|
{
|
||||||
|
const char *input = "";
|
||||||
|
rc = vec0_parse_vector_column(input, 0, &col);
|
||||||
|
assert(rc != SQLITE_OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: no type
|
||||||
|
{
|
||||||
|
const char *input = "emb";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc != SQLITE_OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: unknown type
|
||||||
|
{
|
||||||
|
const char *input = "emb double[128]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc != SQLITE_OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: missing dimensions
|
||||||
|
{
|
||||||
|
const char *input = "emb float";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc != SQLITE_OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: unknown distance metric
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] distance_metric=hamming";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc != SQLITE_OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: unknown option key
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] foobar=baz";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc != SQLITE_OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: distance_metric on bit type
|
||||||
|
{
|
||||||
|
const char *input = "emb bit[64] distance_metric=cosine";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc != SQLITE_OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf(" All vec0_parse_vector_column tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests vec0_parse_partition_key_definition(), which parses a vec0 partition
|
||||||
|
// key column definition like "user_id integer partition key". Verifies correct
|
||||||
|
// parsing of integer and text partition keys, column name extraction, and
|
||||||
|
// rejection of invalid inputs: empty strings, non-partition-key definitions
|
||||||
|
// ("primary key"), and misspelled keywords.
|
||||||
void test_vec0_parse_partition_key_definition() {
|
void test_vec0_parse_partition_key_definition() {
|
||||||
printf("Starting %s...\n", __func__);
|
printf("Starting %s...\n", __func__);
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
|
@ -35,7 +423,6 @@ void test_vec0_parse_partition_key_definition() {
|
||||||
&out_column_name_length,
|
&out_column_name_length,
|
||||||
&out_column_type
|
&out_column_type
|
||||||
);
|
);
|
||||||
printf("2\n");
|
|
||||||
assert(rc == suite[i].expected_rc);
|
assert(rc == suite[i].expected_rc);
|
||||||
|
|
||||||
if(rc == SQLITE_OK) {
|
if(rc == SQLITE_OK) {
|
||||||
|
|
@ -44,11 +431,15 @@ void test_vec0_parse_partition_key_definition() {
|
||||||
assert(out_column_type == suite[i].expected_column_type);
|
assert(out_column_type == suite[i].expected_column_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
printf("✅ %s\n", suite[i].test);
|
printf(" Passed: \"%s\"\n", suite[i].test);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
printf("Starting unit tests...\n");
|
printf("Starting unit tests...\n");
|
||||||
|
test_vec0_token_next();
|
||||||
|
test_vec0_scanner();
|
||||||
|
test_vec0_parse_vector_column();
|
||||||
test_vec0_parse_partition_key_definition();
|
test_vec0_parse_partition_key_definition();
|
||||||
|
printf("All unit tests passed.\n");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,30 @@
|
||||||
fn main() {
|
fn main() {
|
||||||
println!("Hello, world!");
|
println!("Hello, world!");
|
||||||
println!("{:?}", _min_idx(vec![3.0, 2.0, 1.0], 2));
|
println!("{:?}", _min_idx(vec![3.0, 2.0, 1.0, f32::MAX, f32::MAX, f32::MAX, f32::MAX, f32::MAX], 2));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _min_idx(distances: Vec<f32>, k: i32) -> Vec<i32> {
|
fn _min_idx(distances: Vec<f32>, k: i32) -> Vec<i32> {
|
||||||
|
let n = distances.len();
|
||||||
|
assert!(n % 8 == 0, "distances.len() must be a multiple of 8");
|
||||||
|
|
||||||
let mut out: Vec<i32> = vec![0; k as usize];
|
let mut out: Vec<i32> = vec![0; k as usize];
|
||||||
|
let bitmap_bytes = n / 8;
|
||||||
|
let mut candidates: Vec<u8> = vec![0xFF; bitmap_bytes];
|
||||||
|
let mut b_taken: Vec<u8> = vec![0; bitmap_bytes];
|
||||||
|
let mut k_used: i32 = 0;
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
min_idx(
|
min_idx(
|
||||||
distances.as_ptr().cast(),
|
distances.as_ptr(),
|
||||||
distances.len() as i32,
|
n as i32,
|
||||||
|
candidates.as_mut_ptr(),
|
||||||
out.as_mut_ptr(),
|
out.as_mut_ptr(),
|
||||||
k,
|
k,
|
||||||
|
b_taken.as_mut_ptr(),
|
||||||
|
&mut k_used,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
out.truncate(k_used as usize);
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -51,7 +62,15 @@ fn _merge_sorted_lists(
|
||||||
|
|
||||||
#[link(name = "sqlite-vec-internal")]
|
#[link(name = "sqlite-vec-internal")]
|
||||||
extern "C" {
|
extern "C" {
|
||||||
fn min_idx(distances: *const f32, n: i32, out: *mut i32, k: i32) -> i32;
|
fn min_idx(
|
||||||
|
distances: *const f32,
|
||||||
|
n: i32,
|
||||||
|
candidates: *mut u8,
|
||||||
|
out: *mut i32,
|
||||||
|
k: i32,
|
||||||
|
b_taken: *mut u8,
|
||||||
|
k_used: *mut i32,
|
||||||
|
) -> i32;
|
||||||
|
|
||||||
fn merge_sorted_lists(
|
fn merge_sorted_lists(
|
||||||
a: *const f32,
|
a: *const f32,
|
||||||
|
|
@ -74,11 +93,17 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_basic() {
|
fn test_basic() {
|
||||||
assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 3), vec![0, 1, 2]);
|
let pad = |v: &[f32]| -> Vec<f32> {
|
||||||
assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 3), vec![2, 1, 0]);
|
let mut r = v.to_vec();
|
||||||
|
r.resize(8, f32::MAX);
|
||||||
|
r
|
||||||
|
};
|
||||||
|
|
||||||
assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 2), vec![0, 1]);
|
assert_eq!(_min_idx(pad(&[1.0, 2.0, 3.0]), 3), vec![0, 1, 2]);
|
||||||
assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 2), vec![2, 1]);
|
assert_eq!(_min_idx(pad(&[3.0, 2.0, 1.0]), 3), vec![2, 1, 0]);
|
||||||
|
|
||||||
|
assert_eq!(_min_idx(pad(&[1.0, 2.0, 3.0]), 2), vec![0, 1]);
|
||||||
|
assert_eq!(_min_idx(pad(&[3.0, 2.0, 1.0]), 2), vec![2, 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue