diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index 4b29b16a..87236e2a 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -302,18 +302,19 @@ def update_gradients(model): while not isEnd: isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, model.X, model.Z, model.likelihood, model.Y) if isinstance(model.X, VariationalPosterior): + X_slice = model.X[n_range[0]:n_range[1]] #gradients w.r.t. kernel - model.kern.update_gradients_expectations(variational_posterior=model.X[n_range[0]:n_range[1]], Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2']) + model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2']) kern_grad += model.kern.gradient #gradients w.r.t. Z model.Z.gradient[:,model.kern.active_dims] += model.kern.gradients_Z_expectations( - grad_dict['dL_dpsi1'], grad_dict['dL_dpsi2'], Z=model.Z, variational_posterior=model.X[n_range[0]:n_range[1]]) + grad_dict['dL_dpsi1'], grad_dict['dL_dpsi2'], Z=model.Z, variational_posterior=X_slice) #gradients w.r.t. posterior parameters of X - X_grad = model.kern.gradients_qX_expectations(variational_posterior=model.X[n_range[0]:n_range[1]], Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2']) - model.set_X_gradients(model.X[n_range[0]:n_range[1]], X_grad) + X_grad = model.kern.gradients_qX_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2']) + model.set_X_gradients(X_slice, X_grad) if het_noise: dL_dthetaL[n_range[0]:n_range[1]] = grad_dict['dL_dthetaL'] diff --git a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py index 14414222..a695d14e 100644 --- a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py +++ b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py @@ -328,7 +328,7 @@ class PSICOMP_SSRBF(object): self.gpuCacheAll = None self.gpuCache = None - @Cache_this(limit=1) + @Cache_this(limit=1,ignore_args=(0,)) def psicomputations(self, variance, lengthscale, Z, mu, S, gamma): """Compute Psi statitsitcs""" if isinstance(lengthscale, np.ndarray) and len(lengthscale)>1: @@ -373,7 +373,7 @@ class PSICOMP_SSRBF(object): return psi0_gpu, psi1_gpu, psi2_gpu - @Cache_this(limit=1) + @Cache_this(limit=1,ignore_args=(0,)) def _psiDercomputations(self, variance, lengthscale, Z, mu, S, gamma): """Compute the derivatives w.r.t. Psi statistics""" N, M, Q = mu.shape[0],Z.shape[0], mu.shape[1] diff --git a/GPy/util/caching.py b/GPy/util/caching.py index 282c9f8c..676c3ab8 100644 --- a/GPy/util/caching.py +++ b/GPy/util/caching.py @@ -66,6 +66,7 @@ class Cacher(object): #first make sure the depth limit isn't exceeded if len(self.cached_inputs) == self.limit: args_ = self.cached_inputs.pop(0) + args_ = [a for i,a in enumerate(args_) if i not in self.ignore_args and i not in self.force_kwargs] [a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None] self.inputs_changed.pop(0) self.cached_outputs.pop(0)