mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 08:46:49 +02:00
fixes and tests
This commit is contained in:
parent
4039328eda
commit
a657b3a216
4 changed files with 121 additions and 32 deletions
|
|
@ -132,29 +132,45 @@ def tests_command(file_path):
|
|||
conditions = test["conditions"]
|
||||
expected_closest_ids = test["closest_ids"]
|
||||
expected_closest_scores = test["closest_scores"]
|
||||
if "or" in conditions:
|
||||
num_or_skips += 1
|
||||
continue
|
||||
|
||||
sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
|
||||
params = [serialize_float32(query), len(expected_closest_ids)]
|
||||
|
||||
for condition in conditions["and"]:
|
||||
assert len(condition.keys()) == 1
|
||||
column = list(condition.keys())[0]
|
||||
assert len(list(condition[column].keys())) == 1
|
||||
condition_type = list(condition[column].keys())[0]
|
||||
if condition_type == "match":
|
||||
value = condition[column]["match"]["value"]
|
||||
sql += f" and {column} = ?"
|
||||
params.append(value)
|
||||
elif condition_type == "range":
|
||||
sql += f" and {column} between ? and ?"
|
||||
params.append(condition[column]["range"]["gt"])
|
||||
params.append(condition[column]["range"]["lt"])
|
||||
else:
|
||||
raise Exception(f"Unknown condition type: {condition_type}")
|
||||
if "and" in conditions:
|
||||
for condition in conditions["and"]:
|
||||
assert len(condition.keys()) == 1
|
||||
column = list(condition.keys())[0]
|
||||
assert len(list(condition[column].keys())) == 1
|
||||
condition_type = list(condition[column].keys())[0]
|
||||
if condition_type == "match":
|
||||
value = condition[column]["match"]["value"]
|
||||
sql += f" and {column} = ?"
|
||||
params.append(value)
|
||||
elif condition_type == "range":
|
||||
sql += f" and {column} between ? and ?"
|
||||
params.append(condition[column]["range"]["gt"])
|
||||
params.append(condition[column]["range"]["lt"])
|
||||
else:
|
||||
raise Exception(f"Unknown condition type: {condition_type}")
|
||||
elif "or" in conditions:
|
||||
column = list(conditions["or"][0].keys())[0]
|
||||
condition_type = list(conditions["or"][0][column].keys())[0]
|
||||
assert condition_type == "match"
|
||||
sql += f" and {column} in ("
|
||||
for idx, condition in enumerate(conditions["or"]):
|
||||
if condition_type == "match":
|
||||
value = condition[column]["match"]["value"]
|
||||
if idx != 0:
|
||||
sql += ","
|
||||
sql += "?"
|
||||
params.append(value)
|
||||
elif condition_type == "range":
|
||||
breakpoint()
|
||||
else:
|
||||
raise Exception(f"Unknown condition type: {condition_type}")
|
||||
sql += ")"
|
||||
|
||||
# print(sql, params[1:])
|
||||
rows = db.execute(sql, params).fetchall()
|
||||
actual_closest_ids = [row["rowid"] for row in rows]
|
||||
matches = expected_closest_ids == actual_closest_ids
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue