avoid referencing npt when numpy isn't installed

This commit is contained in:
Alex Garcia 2024-08-05 16:19:10 -07:00
parent daf4e05491
commit 2a8593a755

View file

@ -1,24 +1,22 @@
from typing import List from typing import List
from struct import pack from struct import pack
from sqlite3 import Connection from sqlite3 import Connection
def serialize_float32(vector: List[float]) -> bytes: def serialize_float32(vector: List[float]) -> bytes:
""" Serializes a list of floats into the "raw bytes" format sqlite-vec expects """ """Serializes a list of floats into the "raw bytes" format sqlite-vec expects"""
return pack("%sf" % len(vector), *vector) return pack("%sf" % len(vector), *vector)
def serialize_int8(vector: List[int]) -> bytes: def serialize_int8(vector: List[int]) -> bytes:
""" Serializes a list of integers into the "raw bytes" format sqlite-vec expects """ """Serializes a list of integers into the "raw bytes" format sqlite-vec expects"""
return pack("%sb" % len(vector), *vector) return pack("%sb" % len(vector), *vector)
def serialize_(vector: List[int]) -> bytes:
""" Serializes a list of integers into the "raw bytes" format sqlite-vec expects """
return pack("%sb" % len(vector), *vector)
try: try:
import numpy.typing as npt import numpy.typing as npt
def register_numpy(db: Connection, name:str, array: npt.NDArray): def register_numpy(db: Connection, name: str, array: npt.NDArray):
"""ayoo""" """ayoo"""
ptr = array.__array_interface__["data"][0] ptr = array.__array_interface__["data"][0]
@ -42,5 +40,6 @@ try:
) )
except ImportError: except ImportError:
def register_numpy(db: Connection, name:str, array: npt.NDArray):
raise Exception("numpy package is required for register_np") def register_numpy(db: Connection, name: str, array):
raise Exception("numpy package is required for register_numpy")