(much) faster comparison between arrays. Useful for kernel caching

This commit is contained in:
Nicolo Fusi 2013-07-16 17:08:40 +01:00
parent f19a26a006
commit d46b07b18d
3 changed files with 50 additions and 7 deletions

View file

@ -5,6 +5,7 @@
from kernpart import Kernpart
import numpy as np
from ...util.linalg import tdot
from ...util.misc import fast_array_equal
from scipy import weave
class Linear(Kernpart):
@ -266,7 +267,7 @@ class Linear(Kernpart):
#---------------------------------------#
def _K_computations(self, X, X2):
if not (np.array_equal(X, self._Xcache) and np.array_equal(X2, self._X2cache)):
if not (fast_array_equal(X, self._Xcache) and fast_array_equal(X2, self._X2cache)):
self._Xcache = X.copy()
if X2 is None:
self._dot_product = tdot(X)
@ -277,8 +278,8 @@ class Linear(Kernpart):
def _psi_computations(self, Z, mu, S):
# here are the "statistics" for psi1 and psi2
Zv_changed = not (np.array_equal(Z, self._Z) and np.array_equal(self.variances, self._variances))
muS_changed = not (np.array_equal(mu, self._mu) and np.array_equal(S, self._S))
Zv_changed = not (fast_array_equal(Z, self._Z) and fast_array_equal(self.variances, self._variances))
muS_changed = not (fast_array_equal(mu, self._mu) and fast_array_equal(S, self._S))
if Zv_changed:
# Z has changed, compute Z specific stuff
# self.ZZ = Z[:,None,:]*Z[None,:,:] # num_inducing,num_inducing,input_dim

View file

@ -7,6 +7,7 @@ import numpy as np
import hashlib
from scipy import weave
from ...util.linalg import tdot
from ...util.misc import fast_array_equal
class RBF(Kernpart):
"""
@ -222,7 +223,7 @@ class RBF(Kernpart):
#---------------------------------------#
def _K_computations(self, X, X2):
if not (np.array_equal(X, self._X) and np.array_equal(X2, self._X2) and np.array_equal(self._params , self._get_params())):
if not (fast_array_equal(X, self._X) and fast_array_equal(X2, self._X2) and fast_array_equal(self._params , self._get_params())):
self._X = X.copy()
self._params == self._get_params().copy()
if X2 is None:
@ -239,14 +240,14 @@ class RBF(Kernpart):
def _psi_computations(self, Z, mu, S):
# here are the "statistics" for psi1 and psi2
if not np.array_equal(Z, self._Z):
if not fast_array_equal(Z, self._Z):
#Z has changed, compute Z specific stuff
self._psi2_Zhat = 0.5*(Z[:,None,:] +Z[None,:,:]) # M,M,Q
self._psi2_Zdist = 0.5*(Z[:,None,:]-Z[None,:,:]) # M,M,Q
self._psi2_Zdist_sq = np.square(self._psi2_Zdist/self.lengthscale) # M,M,Q
self._Z = Z
if not (np.array_equal(Z, self._Z) and np.array_equal(mu, self._mu) and np.array_equal(S, self._S)):
if not (fast_array_equal(Z, self._Z) and fast_array_equal(mu, self._mu) and fast_array_equal(S, self._S)):
#something's changed. recompute EVERYTHING
#psi1