mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
caching bugfix for psi2 computations
This commit is contained in:
parent
dcec9d2a25
commit
9b161b7440
2 changed files with 76 additions and 79 deletions
|
|
@ -4,7 +4,6 @@
|
||||||
|
|
||||||
from kernpart import Kernpart
|
from kernpart import Kernpart
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import hashlib
|
|
||||||
from scipy import weave
|
from scipy import weave
|
||||||
from ...util.linalg import tdot
|
from ...util.linalg import tdot
|
||||||
from ...util.misc import fast_array_equal
|
from ...util.misc import fast_array_equal
|
||||||
|
|
@ -225,7 +224,7 @@ class RBF(Kernpart):
|
||||||
def _K_computations(self, X, X2):
|
def _K_computations(self, X, X2):
|
||||||
if not (fast_array_equal(X, self._X) and fast_array_equal(X2, self._X2) and fast_array_equal(self._params , self._get_params())):
|
if not (fast_array_equal(X, self._X) and fast_array_equal(X2, self._X2) and fast_array_equal(self._params , self._get_params())):
|
||||||
self._X = X.copy()
|
self._X = X.copy()
|
||||||
self._params == self._get_params().copy()
|
self._params = self._get_params().copy()
|
||||||
if X2 is None:
|
if X2 is None:
|
||||||
self._X2 = None
|
self._X2 = None
|
||||||
X = X / self.lengthscale
|
X = X / self.lengthscale
|
||||||
|
|
@ -245,7 +244,6 @@ class RBF(Kernpart):
|
||||||
self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q
|
self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q
|
||||||
self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q
|
self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q
|
||||||
self._psi2_Zdist_sq = np.square(self._psi2_Zdist / self.lengthscale) # M,M,Q
|
self._psi2_Zdist_sq = np.square(self._psi2_Zdist / self.lengthscale) # M,M,Q
|
||||||
self._Z = Z
|
|
||||||
|
|
||||||
if not (fast_array_equal(Z, self._Z) and fast_array_equal(mu, self._mu) and fast_array_equal(S, self._S)):
|
if not (fast_array_equal(Z, self._Z) and fast_array_equal(mu, self._mu) and fast_array_equal(S, self._S)):
|
||||||
# something's changed. recompute EVERYTHING
|
# something's changed. recompute EVERYTHING
|
||||||
|
|
|
||||||
|
|
@ -217,7 +217,7 @@ class RBFInv(RBF):
|
||||||
def _K_computations(self, X, X2):
|
def _K_computations(self, X, X2):
|
||||||
if not (np.array_equal(X, self._X) and np.array_equal(X2, self._X2) and np.array_equal(self._params , self._get_params())):
|
if not (np.array_equal(X, self._X) and np.array_equal(X2, self._X2) and np.array_equal(self._params , self._get_params())):
|
||||||
self._X = X.copy()
|
self._X = X.copy()
|
||||||
self._params == self._get_params().copy()
|
self._params = self._get_params().copy()
|
||||||
if X2 is None:
|
if X2 is None:
|
||||||
self._X2 = None
|
self._X2 = None
|
||||||
X = X * self.inv_lengthscale
|
X = X * self.inv_lengthscale
|
||||||
|
|
@ -237,7 +237,6 @@ class RBFInv(RBF):
|
||||||
self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q
|
self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :]) # M,M,Q
|
||||||
self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q
|
self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :]) # M,M,Q
|
||||||
self._psi2_Zdist_sq = np.square(self._psi2_Zdist * self.inv_lengthscale) # M,M,Q
|
self._psi2_Zdist_sq = np.square(self._psi2_Zdist * self.inv_lengthscale) # M,M,Q
|
||||||
self._Z = Z
|
|
||||||
|
|
||||||
if not (np.array_equal(Z, self._Z) and np.array_equal(mu, self._mu) and np.array_equal(S, self._S)):
|
if not (np.array_equal(Z, self._Z) and np.array_equal(mu, self._mu) and np.array_equal(S, self._S)):
|
||||||
# something's changed. recompute EVERYTHING
|
# something's changed. recompute EVERYTHING
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue