mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
avoid referencing npt when numpy isn't installed
This commit is contained in:
parent
daf4e05491
commit
2a8593a755
1 changed files with 8 additions and 9 deletions
|
|
@ -1,24 +1,22 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from struct import pack
|
from struct import pack
|
||||||
from sqlite3 import Connection
|
from sqlite3 import Connection
|
||||||
|
|
||||||
|
|
||||||
def serialize_float32(vector: List[float]) -> bytes:
|
def serialize_float32(vector: List[float]) -> bytes:
|
||||||
""" Serializes a list of floats into the "raw bytes" format sqlite-vec expects """
|
"""Serializes a list of floats into the "raw bytes" format sqlite-vec expects"""
|
||||||
return pack("%sf" % len(vector), *vector)
|
return pack("%sf" % len(vector), *vector)
|
||||||
|
|
||||||
|
|
||||||
def serialize_int8(vector: List[int]) -> bytes:
|
def serialize_int8(vector: List[int]) -> bytes:
|
||||||
""" Serializes a list of integers into the "raw bytes" format sqlite-vec expects """
|
"""Serializes a list of integers into the "raw bytes" format sqlite-vec expects"""
|
||||||
return pack("%sb" % len(vector), *vector)
|
return pack("%sb" % len(vector), *vector)
|
||||||
|
|
||||||
def serialize_(vector: List[int]) -> bytes:
|
|
||||||
""" Serializes a list of integers into the "raw bytes" format sqlite-vec expects """
|
|
||||||
return pack("%sb" % len(vector), *vector)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
|
||||||
def register_numpy(db: Connection, name:str, array: npt.NDArray):
|
def register_numpy(db: Connection, name: str, array: npt.NDArray):
|
||||||
"""ayoo"""
|
"""ayoo"""
|
||||||
|
|
||||||
ptr = array.__array_interface__["data"][0]
|
ptr = array.__array_interface__["data"][0]
|
||||||
|
|
@ -42,5 +40,6 @@ try:
|
||||||
)
|
)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
def register_numpy(db: Connection, name:str, array: npt.NDArray):
|
|
||||||
raise Exception("numpy package is required for register_np")
|
def register_numpy(db: Connection, name: str, array):
|
||||||
|
raise Exception("numpy package is required for register_numpy")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue