mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
old: test-knn-constraints
This commit is contained in:
parent
34c49da26c
commit
bbb3238209
3 changed files with 356 additions and 1 deletions
2
test.sql
2
test.sql
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
.load dist/vec0main
|
||||
.load dist/vec0
|
||||
.bail on
|
||||
|
||||
.mode qbox
|
||||
|
|
|
|||
273
tests/__snapshots__/test-knn-distance-constraints.ambr
Normal file
273
tests/__snapshots__/test-knn-distance-constraints.ambr
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
# serializer version: 1
|
||||
# name: test_normal
|
||||
OrderedDict({
|
||||
'sql': 'SELECT * FROM v',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'embedding': b'\x00\x00\x80?',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'embedding': b'\x00\x00\x00@',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'embedding': b'\x00\x00@@',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'embedding': b'\x00\x00\x80@',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'embedding': b'\x00\x00\xa0@',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'embedding': b'\x00\x00\xc0@',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 7,
|
||||
'embedding': b'\x00\x00\xe0@',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 8,
|
||||
'embedding': b'\x00\x00\x00A',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 9,
|
||||
'embedding': b'\x00\x00\x10A',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 10,
|
||||
'embedding': b'\x00\x00 A',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 11,
|
||||
'embedding': b'\x00\x000A',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 12,
|
||||
'embedding': b'\x00\x00@A',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 13,
|
||||
'embedding': b'\x00\x00PA',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 14,
|
||||
'embedding': b'\x00\x00`A',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 15,
|
||||
'embedding': b'\x00\x00pA',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 16,
|
||||
'embedding': b'\x00\x00\x80A',
|
||||
'is_odd': 0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 17,
|
||||
'embedding': b'\x00\x00\x88A',
|
||||
'is_odd': 1,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.1
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? ',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'distance': 0.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'distance': 1.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'distance': 2.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'distance': 3.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'distance': 4.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.2
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance > 5',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 7,
|
||||
'distance': 6.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 8,
|
||||
'distance': 7.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 9,
|
||||
'distance': 8.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 10,
|
||||
'distance': 9.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 11,
|
||||
'distance': 10.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.3
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance >= 5',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'distance': 5.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 7,
|
||||
'distance': 6.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 8,
|
||||
'distance': 7.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 9,
|
||||
'distance': 8.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 10,
|
||||
'distance': 9.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.4
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance < 3',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'distance': 0.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'distance': 1.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'distance': 2.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.5
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance <= 3',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'distance': 0.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'distance': 1.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'distance': 2.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'distance': 3.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.6
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance > 7 AND distance <= 10',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 9,
|
||||
'distance': 8.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 10,
|
||||
'distance': 9.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 11,
|
||||
'distance': 10.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.7
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance BETWEEN 7 AND 10',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 8,
|
||||
'distance': 7.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 9,
|
||||
'distance': 8.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 10,
|
||||
'distance': 9.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 11,
|
||||
'distance': 10.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_normal.8
|
||||
OrderedDict({
|
||||
'sql': 'select rowid, distance from v where embedding match ? and k = ? AND is_odd == TRUE AND distance BETWEEN 7 AND 10',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 9,
|
||||
'distance': 8.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 11,
|
||||
'distance': 10.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
82
tests/test-knn-distance-constraints.py
Normal file
82
tests/test-knn-distance-constraints.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
import sqlite3
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def test_normal(db, snapshot):
|
||||
db.execute("create virtual table v using vec0(embedding float[1], is_odd boolean, chunk_size=8)")
|
||||
db.executemany(
|
||||
"insert into v(rowid, is_odd, embedding) values (?1, ?1 % 2, ?2)",
|
||||
[
|
||||
[1, "[1]"],
|
||||
[2, "[2]"],
|
||||
[3, "[3]"],
|
||||
[4, "[4]"],
|
||||
[5, "[5]"],
|
||||
[6, "[6]"],
|
||||
[7, "[7]"],
|
||||
[8, "[8]"],
|
||||
[9, "[9]"],
|
||||
[10, "[10]"],
|
||||
[11, "[11]"],
|
||||
[12, "[12]"],
|
||||
[13, "[13]"],
|
||||
[14, "[14]"],
|
||||
[15, "[15]"],
|
||||
[16, "[16]"],
|
||||
[17, "[17]"],
|
||||
],
|
||||
)
|
||||
assert exec(db,"SELECT * FROM v") == snapshot()
|
||||
|
||||
BASE_KNN = "select rowid, distance from v where embedding match ? and k = ? "
|
||||
assert exec(db, BASE_KNN, ["[1]", 5]) == snapshot()
|
||||
assert exec(db, BASE_KNN + "AND distance > 5", ["[1]", 5]) == snapshot()
|
||||
assert exec(db, BASE_KNN + "AND distance >= 5", ["[1]", 5]) == snapshot()
|
||||
assert exec(db, BASE_KNN + "AND distance < 3", ["[1]", 5]) == snapshot()
|
||||
assert exec(db, BASE_KNN + "AND distance <= 3", ["[1]", 5]) == snapshot()
|
||||
assert exec(db, BASE_KNN + "AND distance > 7 AND distance <= 10", ["[1]", 5]) == snapshot()
|
||||
assert exec(db, BASE_KNN + "AND distance BETWEEN 7 AND 10", ["[1]", 5]) == snapshot()
|
||||
assert exec(db, BASE_KNN + "AND is_odd == TRUE AND distance BETWEEN 7 AND 10", ["[1]", 5]) == snapshot()
|
||||
|
||||
|
||||
class Row:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr()
|
||||
|
||||
|
||||
def exec(db, sql, parameters=[]):
|
||||
try:
|
||||
rows = db.execute(sql, parameters).fetchall()
|
||||
except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
|
||||
return {
|
||||
"error": e.__class__.__name__,
|
||||
"message": str(e),
|
||||
}
|
||||
a = []
|
||||
for row in rows:
|
||||
o = OrderedDict()
|
||||
for k in row.keys():
|
||||
o[k] = row[k]
|
||||
a.append(o)
|
||||
result = OrderedDict()
|
||||
result["sql"] = sql
|
||||
result["rows"] = a
|
||||
return result
|
||||
|
||||
|
||||
def vec0_shadow_table_contents(db, v):
|
||||
shadow_tables = [
|
||||
row[0]
|
||||
for row in db.execute(
|
||||
"select name from sqlite_master where name like ? order by 1", [f"{v}_%"]
|
||||
).fetchall()
|
||||
]
|
||||
o = {}
|
||||
for shadow_table in shadow_tables:
|
||||
if shadow_table.endswith("_info"):
|
||||
continue
|
||||
o[shadow_table] = exec(db, f"select * from {shadow_table}")
|
||||
return o
|
||||
Loading…
Add table
Add a link
Reference in a new issue