caching bugfix for psi2 computations

This commit is contained in:
Max Zwiessele 2013-07-18 15:39:58 +01:00
parent dcec9d2a25
commit 9b161b7440
2 changed files with 76 additions and 79 deletions

View file

@ -4,7 +4,6 @@
from kernpart import Kernpart from kernpart import Kernpart
import numpy as np import numpy as np
import hashlib
from scipy import weave from scipy import weave
from ...util.linalg import tdot from ...util.linalg import tdot
from ...util.misc import fast_array_equal from ...util.misc import fast_array_equal
@ -225,7 +224,7 @@ class RBF(Kernpart):
def _K_computations(self, X, X2): def _K_computations(self, X, X2):
if not (fast_array_equal(X, self._X) and fast_array_equal(X2, self._X2) and fast_array_equal(self._params , self._get_params())): if not (fast_array_equal(X, self._X) and fast_array_equal(X2, self._X2) and fast_array_equal(self._params , self._get_params())):
self._X = X.copy() self._X = X.copy()
self._params == self._get_params().copy() self._params = self._get_params().copy()
if X2 is None: if X2 is None:
self._X2 = None self._X2 = None
X = X / self.lengthscale X = X / self.lengthscale
@ -245,7 +244,6 @@ class RBF(Kernpart):
self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q
self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q
self._psi2_Zdist_sq = np.square(self._psi2_Zdist / self.lengthscale) # M,M,Q self._psi2_Zdist_sq = np.square(self._psi2_Zdist / self.lengthscale) # M,M,Q
self._Z = Z
if not (fast_array_equal(Z, self._Z) and fast_array_equal(mu, self._mu) and fast_array_equal(S, self._S)): if not (fast_array_equal(Z, self._Z) and fast_array_equal(mu, self._mu) and fast_array_equal(S, self._S)):
# something's changed. recompute EVERYTHING # something's changed. recompute EVERYTHING

View file

@ -217,7 +217,7 @@ class RBFInv(RBF):
def _K_computations(self, X, X2): def _K_computations(self, X, X2):
if not (np.array_equal(X, self._X) and np.array_equal(X2, self._X2) and np.array_equal(self._params , self._get_params())): if not (np.array_equal(X, self._X) and np.array_equal(X2, self._X2) and np.array_equal(self._params , self._get_params())):
self._X = X.copy() self._X = X.copy()
self._params == self._get_params().copy() self._params = self._get_params().copy()
if X2 is None: if X2 is None:
self._X2 = None self._X2 = None
X = X * self.inv_lengthscale X = X * self.inv_lengthscale
@ -237,7 +237,6 @@ class RBFInv(RBF):
self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q
self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q
self._psi2_Zdist_sq = np.square(self._psi2_Zdist * self.inv_lengthscale) # M,M,Q self._psi2_Zdist_sq = np.square(self._psi2_Zdist * self.inv_lengthscale) # M,M,Q
self._Z = Z
if not (np.array_equal(Z, self._Z) and np.array_equal(mu, self._mu) and np.array_equal(S, self._S)): if not (np.array_equal(Z, self._Z) and np.array_equal(mu, self._mu) and np.array_equal(S, self._S)):
# something's changed. recompute EVERYTHING # something's changed. recompute EVERYTHING