refactor some text knn filtering

This commit is contained in:
Alex Garcia 2024-11-18 11:21:49 -08:00
parent 1a216a684d
commit 10a2216845
3 changed files with 101 additions and 105 deletions

View file

@ -5852,6 +5852,104 @@ int vec0_chunks_iter(vec0_vtab * p, const char * idxStr, int argc, sqlite3_value
return rc; return rc;
} }
int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * buffer, int size, vec0_metadata_operator op, u8* b, int metadata_idx, i64 *rowids) {
int rc;
sqlite3_stmt * stmt = NULL;
const char * target = (const char *) sqlite3_value_text(value);
int targetn = sqlite3_value_bytes(value);
switch(op) {
case VEC0_METADATA_OPERATOR_EQ: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n != targetn) {
bitmap_set(b, i, 0);
continue;
}
int prefix_cmp = strncmp(s, target, min(n, 12));
if(n <= 12) {
bitmap_set(b, i, prefix_cmp == 0);
}
// if the prefix doesnt match, the rest of the string wont match
else if(prefix_cmp) {
bitmap_set(b, i, 0);
}
// need to consult
else {
char *slong;
int slongn;
rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &slongn, &slong);
if(rc != SQLITE_OK) {
goto done;
}
assert(n == slongn);
bitmap_set(b, i, strncmp(slong, target, n) == 0);
}
}
break;
}
case VEC0_METADATA_OPERATOR_NE: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) != 0);
}
break;
}
case VEC0_METADATA_OPERATOR_GT: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) > 0);
}
break;
}
case VEC0_METADATA_OPERATOR_GE: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) >= 0);
}
break;
}
case VEC0_METADATA_OPERATOR_LE: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) <= 0);
}
break;
}
case VEC0_METADATA_OPERATOR_LT: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) < 0);
}
break;
}
}
rc = SQLITE_OK;
done:
sqlite3_finalize(stmt);
return rc;
}
/** /**
* @brief Fill in bitmap of chunk values, whether or not the values match a metadata constraint * @brief Fill in bitmap of chunk values, whether or not the values match a metadata constraint
* *
@ -6008,95 +6106,7 @@ int vec0_set_metadata_filter_bitmap(
break; break;
} }
case VEC0_METADATA_COLUMN_KIND_TEXT: { case VEC0_METADATA_COLUMN_KIND_TEXT: {
const char * target = (const char *) sqlite3_value_text(value); vec0_metadata_filter_text(p, value, buffer, size, op, b, metadata_idx, rowids);
int targetn = sqlite3_value_bytes(value);
switch(op) {
case VEC0_METADATA_OPERATOR_EQ: {
sqlite3_stmt * stmt = NULL;
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n != targetn) {
bitmap_set(b, i, 0);
continue;
}
int prefix_cmp = strncmp(s, target, min(n, 12));
if(n <= 12) {
bitmap_set(b, i, prefix_cmp == 0);
}
// if the prefix doesnt match, the rest of the string wont match
else if(prefix_cmp) {
bitmap_set(b, i, 0);
}
// need to consult
else {
char *slong;
int slongn;
rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &slongn, &slong);
if(rc != SQLITE_OK) {
goto done;
}
assert(n == slongn);
bitmap_set(b, i, strncmp(slong, target, n) == 0);
}
}
sqlite3_finalize(stmt);
break;
}
case VEC0_METADATA_OPERATOR_NE: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) != 0);
}
break;
}
case VEC0_METADATA_OPERATOR_GT: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) > 0);
}
break;
}
case VEC0_METADATA_OPERATOR_GE: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) >= 0);
}
break;
}
case VEC0_METADATA_OPERATOR_LE: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) <= 0);
}
break;
}
case VEC0_METADATA_OPERATOR_LT: {
for(int i = 0; i < size; i++) {
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
int n = ((int*) view)[0];
char * s = (char *) &view[4];
if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */
bitmap_set(b, i, strncmp(s, target, n) < 0);
}
break;
}
}
break; break;
} }
} }

View file

@ -1434,6 +1434,7 @@
# name: test_stress.1 # name: test_stress.1
OrderedDict({ OrderedDict({
'sql': ''' 'sql': '''
select select
movie_id, movie_id,
title, title,
@ -1448,6 +1449,7 @@
and num_reviews between 100 and 500 and num_reviews between 100 and 500
and mean_rating > 3.5 and mean_rating > 3.5
and k = 5; and k = 5;
''', ''',
'rows': list([ 'rows': list([
OrderedDict({ OrderedDict({
@ -1875,12 +1877,6 @@
]), ]),
}) })
# --- # ---
# name: test_text_knn.10
dict({
'error': 'OperationalError',
'message': 'Could not filter metadata fields',
})
# ---
# name: test_text_knn.2 # name: test_text_knn.2
dict({ dict({
'v_chunks': OrderedDict({ 'v_chunks': OrderedDict({

View file

@ -120,16 +120,6 @@ def test_text_knn(db, snapshot):
== snapshot() == snapshot()
) )
# this break KNN :(
db.execute("insert into v(vector, name) values ('[3.0]', '1234567890123')")
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[.01]' and k = 5 and name != 'aaa'",
)
== snapshot()
)
def test_long_text_updates(db, snapshot): def test_long_text_updates(db, snapshot):
db.execute( db.execute(