init pass vec_min vec_max vec_avg

This commit is contained in:
Alex Garcia 2024-06-12 00:10:00 -07:00
parent 6875f7649c
commit 44dcb3b391
3 changed files with 333 additions and 0 deletions

View file

@ -86,6 +86,7 @@ def spread_args(args):
FUNCTIONS = [
"vec_add",
"vec_avg",
"vec_bit",
"vec_debug",
"vec_distance_cosine",
@ -94,6 +95,8 @@ FUNCTIONS = [
"vec_f32",
"vec_int8",
"vec_length",
"vec_max",
"vec_min",
"vec_normalize",
"vec_quantize_binary",
"vec_quantize_i8",
@ -459,6 +462,40 @@ def test_vec_sub():
):
vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)")
def test_vec_min():
def vec_min(values, wrap="(vec_f32(?))"):
return db.execute(
"select vec_min(column1) from (values {})".format(", ".join([wrap] * len(values))), values
).fetchone()[0]
assert vec_min(["[1]", "[2]"]) == _f32([1])
assert vec_min(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8])
# TODO: int8 tests, block binary vectors, overflowing
with pytest.raises(
sqlite3.OperationalError,
match=re.escape("vec_min(): vector dimensions do not match."),
):
vec_min(["[1]", "[2,3]"])
def test_vec_max():
def vec_max(values, wrap="(vec_f32(?))"):
return db.execute(
"select vec_max(column1) from (values {})".format(", ".join([wrap] * len(values))), values
).fetchone()[0]
assert vec_max(["[1]", "[2]"]) == _f32([1])
assert vec_max(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8])
# TODO: int8 tests, block binary vectors, overflowing
with pytest.raises(
sqlite3.OperationalError,
match=re.escape("vec_min(): vector dimensions do not match."),
):
vec_min(["[1]", "[2,3]"])
def test_vec_to_json():
vec_to_json = lambda *args, input="?": db.execute(