From 2a8593a755e4a4245433f714b4b4b6d3d77ca51d Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Mon, 5 Aug 2024 16:19:10 -0700 Subject: [PATCH] avoid referencing npt when numpy isn't installed --- bindings/python/extra_init.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/bindings/python/extra_init.py b/bindings/python/extra_init.py index 6c332ff..267bc41 100644 --- a/bindings/python/extra_init.py +++ b/bindings/python/extra_init.py @@ -1,24 +1,22 @@ - from typing import List from struct import pack from sqlite3 import Connection + def serialize_float32(vector: List[float]) -> bytes: - """ Serializes a list of floats into the "raw bytes" format sqlite-vec expects """ + """Serializes a list of floats into the "raw bytes" format sqlite-vec expects""" return pack("%sf" % len(vector), *vector) + def serialize_int8(vector: List[int]) -> bytes: - """ Serializes a list of integers into the "raw bytes" format sqlite-vec expects """ + """Serializes a list of integers into the "raw bytes" format sqlite-vec expects""" return pack("%sb" % len(vector), *vector) -def serialize_(vector: List[int]) -> bytes: - """ Serializes a list of integers into the "raw bytes" format sqlite-vec expects """ - return pack("%sb" % len(vector), *vector) try: import numpy.typing as npt - def register_numpy(db: Connection, name:str, array: npt.NDArray): + def register_numpy(db: Connection, name: str, array: npt.NDArray): """ayoo""" ptr = array.__array_interface__["data"][0] @@ -42,5 +40,6 @@ try: ) except ImportError: - def register_numpy(db: Connection, name:str, array: npt.NDArray): - raise Exception("numpy package is required for register_np") + + def register_numpy(db: Connection, name: str, array): + raise Exception("numpy package is required for register_numpy")