mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-26 01:06:27 +02:00
52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
|
|
import numpy as np
|
||
|
|
import numpy.typing as npt
|
||
|
|
import time
|
||
|
|
|
||
|
|
def cosine_similarity(
    vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32], do_norm: bool = True
) -> npt.NDArray[np.float32]:
    """Score `vec` against every row of `mat` via dot product.

    When `do_norm` is True the raw dot products are divided by the norms of
    `vec` and of each row of `mat`, yielding true cosine similarity; when
    False the unnormalized dot products are returned as-is.
    """
    scores = vec @ mat.T
    if not do_norm:
        return scores
    scores /= np.linalg.norm(vec) * np.linalg.norm(mat, axis=1)
    return scores
|
||
|
|
|
||
|
|
|
||
|
|
def topk(
    vec: npt.NDArray[np.float32],
    mat: npt.NDArray[np.float32],
    k: int = 5,
    do_norm: bool = True,
) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.float32]]:
    """Return (row indices, similarity scores) of the `k` rows of `mat`
    most similar to `vec`, both ordered by descending similarity.

    Bug fix: the original returned `sim[top_indices]`, where `top_indices`
    are positions *within the partitioned subset* — indexing the full score
    array with them returned scores that did not correspond to the returned
    row indices. The scores must come from `sim[indices][top_indices]`.
    """
    sim = cosine_similarity(vec, mat, do_norm=do_norm)
    # np.argpartition requires kth < n; when k covers all rows a plain full
    # sort is equivalent (and the partition trick buys nothing anyway).
    if k >= sim.shape[0]:
        order = np.argsort(-sim)
        return order, sim[order]
    # Rather than sorting all similarities and taking the top K, it's faster to
    # argpartition and then just sort the top K.
    # The difference is O(N logN) vs O(N + k logk)
    indices = np.argpartition(-sim, kth=k)[:k]
    top_indices = np.argsort(-sim[indices])
    # Index the subset's scores, not the full array (see docstring).
    return indices[top_indices], sim[indices][top_indices]
|
||
|
|
|
||
|
|
|
||
|
|
def ivecs_read(fname):
    """Read an .ivecs file into an (n, d) int32 array.

    The .ivecs layout stores each vector as one int32 dimension count `d`
    followed by `d` int32 components; all vectors share the same `d`, so the
    count read from the first record is used to reshape the whole file. The
    per-row dimension column is then stripped off.
    """
    raw = np.fromfile(fname, dtype="int32")
    dim = raw[0]
    records = raw.reshape(-1, dim + 1)
    # Drop the leading dimension column; copy to make the result contiguous.
    return records[:, 1:].copy()
|
||
|
|
|
||
|
|
|
||
|
|
def fvecs_read(fname):
    """Read an .fvecs file into an (n, d) float32 array.

    .fvecs shares the .ivecs layout byte-for-byte (int32 dimension prefix per
    vector), so the file is parsed as .ivecs and the component payload is
    reinterpreted — not converted — as float32.
    """
    as_ints = ivecs_read(fname)
    return as_ints.view("float32")
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
# Benchmark: brute-force top-k cosine similarity over the SIFT dataset.
# Times the first 20 queries against the full base set and prints the mean
# per-query latency.
base = fvecs_read("../../sift/sift_base.fvecs")
queries = fvecs_read("../../sift/sift_query.fvecs")

k = 20
times = []
results = []
for q in queries[0:20]:
    # perf_counter is monotonic and high-resolution — the right clock for
    # benchmarking (time.time can jump with wall-clock adjustments).
    t0 = time.perf_counter()
    result = topk(q, base, k=k)
    results.append(result)
    times.append(time.perf_counter() - t0)

print(np.__version__)
print(np.mean(times))
|