mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
text knn GT/GE fixes
This commit is contained in:
parent
1ec1b89f60
commit
df29e31ddc
3 changed files with 318 additions and 30 deletions
49
sqlite-vec.c
49
sqlite-vec.c
|
|
@ -5968,19 +5968,19 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void *
|
|||
view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
||||
nPrefix = ((int*) view)[0];
|
||||
sPrefix = (char *) &view[4];
|
||||
int cmpPrefix = strncmp(sPrefix, sTarget, min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH));
|
||||
int cmpPrefix = strncmp(sPrefix, sTarget, min(min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH), nTarget));
|
||||
|
||||
// for short strings, use the prefix comparison directly
|
||||
if(nPrefix <= VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {
|
||||
bitmap_set(b, i, cmpPrefix > 0);
|
||||
continue;
|
||||
}
|
||||
|
||||
// for GT, only need to consult full string if EQ
|
||||
if(cmpPrefix != 0) {
|
||||
if(nPrefix < VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {
|
||||
// if prefix match, check which is longer
|
||||
if(cmpPrefix == 0) {
|
||||
bitmap_set(b, i, nPrefix > nTarget);
|
||||
}
|
||||
else {
|
||||
bitmap_set(b, i, cmpPrefix > 0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// TODO(perf): may not need to compare full text in some cases
|
||||
|
||||
rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull);
|
||||
if(rc != SQLITE_OK) {
|
||||
|
|
@ -5996,11 +5996,32 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void *
|
|||
}
|
||||
case VEC0_METADATA_OPERATOR_GE: {
|
||||
for(int i = 0; i < size; i++) {
|
||||
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
||||
int n = ((int*) view)[0];
|
||||
char * s = (char *) &view[4];
|
||||
if(n > VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {rc = SQLITE_ERROR;goto done;} /* TODO */
|
||||
bitmap_set(b, i, strncmp(s, sTarget, n) >= 0);
|
||||
view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
||||
nPrefix = ((int*) view)[0];
|
||||
sPrefix = (char *) &view[4];
|
||||
int cmpPrefix = strncmp(sPrefix, sTarget, min(min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH), nTarget));
|
||||
|
||||
if(nPrefix < VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {
|
||||
// if prefix match, check which is longer
|
||||
if(cmpPrefix == 0) {
|
||||
bitmap_set(b, i, nPrefix >= nTarget);
|
||||
}
|
||||
else {
|
||||
bitmap_set(b, i, cmpPrefix >= 0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// TODO(perf): may not need to compare full text in some cases
|
||||
|
||||
rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull);
|
||||
if(rc != SQLITE_OK) {
|
||||
goto done;
|
||||
}
|
||||
if(nPrefix != nFull) {
|
||||
rc = SQLITE_ERROR;
|
||||
goto done;
|
||||
}
|
||||
bitmap_set(b, i, strncmp(sFull, sTarget, nFull) >= 0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -617,6 +617,13 @@
|
|||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[eq-bb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?",
|
||||
'rows': list([
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[eq-bbbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?",
|
||||
|
|
@ -629,6 +636,13 @@
|
|||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[eq-bbbbbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?",
|
||||
'rows': list([
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[eq-bbbbbbbbbbbb_aaa]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?",
|
||||
|
|
@ -662,34 +676,175 @@
|
|||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ge-bb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'name': 'bbbb',
|
||||
'distance': 97.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ge-bbbb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'name': 'bbbb',
|
||||
'distance': 97.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ge-bbbbbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ge-bbbbbbbbbbbb_aaa]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ge-bbbbbbbbbbbb_bbb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ge-bbbbbbbbbbbb_ccc]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ge-longlonglonglonglonglonglong]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?",
|
||||
'rows': list([
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[gt-bb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name > ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'name': 'bbbb',
|
||||
'distance': 97.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[gt-bbbb]
|
||||
|
|
@ -714,6 +869,28 @@
|
|||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[gt-bbbbbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name > ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[gt-bbbbbbbbbbbb_aaa]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name > ?",
|
||||
|
|
@ -777,12 +954,24 @@
|
|||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[le-bb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[le-bbbb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[le-bbbbbb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[le-bbbbbbbbbbbb_aaa]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
|
|
@ -807,12 +996,24 @@
|
|||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[lt-bb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[lt-bbbb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[lt-bbbbbb]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[lt-bbbbbbbbbbbb_aaa]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
|
|
@ -837,6 +1038,38 @@
|
|||
'message': 'Could not filter metadata fields',
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ne-bb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'name': 'bbbb',
|
||||
'distance': 97.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'name': 'aaaaaaaaaaaa_aaa',
|
||||
'distance': 98.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ne-bbbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?",
|
||||
|
|
@ -869,6 +1102,38 @@
|
|||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ne-bbbbbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'name': 'cccccccccccc_ccc',
|
||||
'distance': 94.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'name': 'cccc',
|
||||
'distance': 95.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'name': 'bbbbbbbbbbbb_bbb',
|
||||
'distance': 96.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'name': 'bbbb',
|
||||
'distance': 97.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'name': 'aaaaaaaaaaaa_aaa',
|
||||
'distance': 98.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_long_text_knn[ne-bbbbbbbbbbbb_aaa]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?",
|
||||
|
|
|
|||
|
|
@ -147,6 +147,8 @@ def test_long_text_knn(db, snapshot):
|
|||
|
||||
tests = [
|
||||
"bbbb",
|
||||
"bb",
|
||||
"bbbbbb",
|
||||
"bbbbbbbbbbbb_bbb",
|
||||
"bbbbbbbbbbbb_aaa",
|
||||
"bbbbbbbbbbbb_ccc",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue