diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py
index 68140763..d54cf208 100644
--- a/GPy/core/parameterization/parameter_core.py
+++ b/GPy/core/parameterization/parameter_core.py
@@ -169,7 +169,7 @@ class Pickleable(object):
         else:
             pickle.dump(self, f, protocol)
 
-    #===========================================================================
+    #===========================================================================
     # copy and pickling
     #===========================================================================
     def copy(self):
diff --git a/GPy/core/parameterization/variational.py b/GPy/core/parameterization/variational.py
index 3730baed..4b2b2e4e 100644
--- a/GPy/core/parameterization/variational.py
+++ b/GPy/core/parameterization/variational.py
@@ -160,7 +160,7 @@ class SpikeAndSlabPosterior(VariationalPosterior):
         else:
             return super(VariationalPrior, self).__getitem__(s)
 
-    def plot(self, *args):
+    def plot(self, *args, **kwargs):
         """
         Plot latent space X in 1D:
 
@@ -169,4 +169,4 @@ class SpikeAndSlabPosterior(VariationalPosterior):
         import sys
         assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
         from ...plotting.matplot_dep import variational_plots
-        return variational_plots.plot_SpikeSlab(self, *args)
+        return variational_plots.plot_SpikeSlab(self, *args, **kwargs)
diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py
index ef298a9b..727d7e1c 100644
--- a/GPy/inference/latent_function_inference/var_dtc_parallel.py
+++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py
@@ -29,6 +29,7 @@ class VarDTC_minibatch(object):
 
         self.batchsize = batchsize
         self.mpi_comm = mpi_comm
+        self.limit = limit
 
         # Cache functions
         from ...util.caching import Cacher
@@ -37,6 +38,20 @@ class VarDTC_minibatch(object):
 
         self.midRes = {}
         self.batch_pos = 0 # the starting position of the current mini-batch
+
+    def __getstate__(self):
+        # has to be overridden, as Cacher objects cannot be pickled.
+        return self.batchsize, self.limit
+
+    def __setstate__(self, state):
+        # has to be overridden, as Cacher objects cannot be pickled.
+        self.batchsize, self.limit = state
+        self.mpi_comm = None
+        self.midRes = {}
+        self.batch_pos = 0
+        from ...util.caching import Cacher
+        self.get_trYYT = Cacher(self._get_trYYT, self.limit)
+        self.get_YYTfactor = Cacher(self._get_YYTfactor, self.limit)
 
     def set_limit(self, limit):
         self.get_trYYT.limit = limit
@@ -334,7 +349,10 @@ def update_gradients(model, mpi_comm=None):
     while not isEnd:
         isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
         if isinstance(model.X, VariationalPosterior):
-            X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
+            if mpi_comm is None:
+                X_slice = model.X[n_range[0]:n_range[1]]
+            else:
+                X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
             #gradients w.r.t. kernel
             model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
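
The `__getstate__`/`__setstate__` pair above exists because GPy's `Cacher` wraps bound methods and cannot be pickled: only the plain configuration (`batchsize`, `limit`) is saved, the caches are rebuilt from scratch on load, and the MPI communicator is deliberately reset to `None`. A minimal standalone sketch of the same pattern (the `CachedWorker` class below is hypothetical, not GPy code):

```python
import pickle

class CachedWorker(object):
    """Hypothetical stand-in showing the save-config / rebuild-cache pattern."""
    def __init__(self, batchsize, limit=3):
        self.batchsize = batchsize
        self.limit = limit
        self._rebuild_cache()

    def _rebuild_cache(self):
        # Plays the role of GPy's Cacher: a closure bound to the instance,
        # which pickle cannot serialize.
        memo = {}
        def compute(x):
            if x not in memo:
                memo[x] = x ** 2
            return memo[x]
        self.compute = compute

    def __getstate__(self):
        # Persist only plain data, mirroring VarDTC_minibatch above.
        return self.batchsize, self.limit

    def __setstate__(self, state):
        self.batchsize, self.limit = state
        self._rebuild_cache()

w = pickle.loads(pickle.dumps(CachedWorker(batchsize=100)))
assert (w.batchsize, w.limit, w.compute(3)) == (100, 3, 9)
```
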
diff --git a/GPy/kern/_src/linear.py b/GPy/kern/_src/linear.py
index 007649b0..6ef71724 100644
--- a/GPy/kern/_src/linear.py
+++ b/GPy/kern/_src/linear.py
@@ -52,6 +52,9 @@ class Linear(Kern):
         self.variances = Param('variances', variances, Logexp())
         self.add_parameter(self.variances)
+
+    def set_for_SpikeAndSlab(self):
+        self.psicomp = linear_psi_comp.PSICOMP_SSLinear()
 
     @Cache_this(limit=2)
     def K(self, X, X2=None):
@@ -107,35 +110,20 @@ class Linear(Kern):
     def psi0(self, Z, variational_posterior):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
-            gamma = variational_posterior.binary_prob
-            mu = variational_posterior.mean
-            S = variational_posterior.variance
-
-            return np.einsum('q,nq,nq->n',self.variances,gamma,np.square(mu)+S)
-#            return (self.variances*gamma*(np.square(mu)+S)).sum(axis=1)
+            return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[0]
         else:
             return np.sum(self.variances * self._mu2S(variational_posterior), 1)
 
     def psi1(self, Z, variational_posterior):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
-            gamma = variational_posterior.binary_prob
-            mu = variational_posterior.mean
-            return np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
-#            return (self.variances*gamma*mu).sum(axis=1)
+            return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[1]
         else:
             return self.K(variational_posterior.mean, Z) #the variance, it does nothing
 
     @Cache_this(limit=1)
     def psi2(self, Z, variational_posterior):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
-            gamma = variational_posterior.binary_prob
-            mu = variational_posterior.mean
-            S = variational_posterior.variance
-            mu2 = np.square(mu)
-            variances2 = np.square(self.variances)
-            tmp = np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
-            return np.einsum('nq,q,mq,oq,nq->nmo',gamma,variances2,Z,Z,mu2+S)+\
-                   np.einsum('nm,no->nmo',tmp,tmp) - np.einsum('nq,q,mq,oq,nq->nmo',np.square(gamma),variances2,Z,Z,mu2)
+            return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[2]
         else:
             ZA = Z * self.variances
             ZAinner = self._ZAinner(variational_posterior, Z)
 
     def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
-            gamma = variational_posterior.binary_prob
-            mu = variational_posterior.mean
-            S = variational_posterior.variance
-            mu2S = np.square(mu)+S
-            _dpsi2_dvariance, _, _, _, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
-            grad = np.einsum('n,nq,nq->q',dL_dpsi0,gamma,mu2S) + np.einsum('nm,nq,mq,nq->q',dL_dpsi1,gamma,Z,mu) +\
-                   np.einsum('nmo,nmoq->q',dL_dpsi2,_dpsi2_dvariance)
+            dL_dvar, _, _, _, _ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
             if self.ARD:
-                self.variances.gradient = grad
+                self.variances.gradient = dL_dvar
             else:
-                self.variances.gradient = grad.sum()
+                self.variances.gradient = dL_dvar.sum()
         else:
             #psi1
             self.update_gradients_full(dL_dpsi1, variational_posterior.mean, Z)
@@ -170,15 +152,8 @@ class Linear(Kern):
     def gradients_Z_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
-            gamma = variational_posterior.binary_prob
-            mu = variational_posterior.mean
-            S = variational_posterior.variance
-            _, _, _, _, _dpsi2_dZ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
-
-            grad = np.einsum('nm,nq,q,nq->mq',dL_dpsi1,gamma,self.variances,mu) +\
-                   np.einsum('nmo,noq->mq',dL_dpsi2,_dpsi2_dZ)
-
-            return grad
+            _, dL_dZ, _, _, _ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
+            return dL_dZ
         else:
             #psi1
             grad = self.gradients_X(dL_dpsi1.T, Z, variational_posterior.mean)
@@ -188,19 +163,8 @@ class Linear(Kern):
     def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
-            gamma = variational_posterior.binary_prob
-            mu = variational_posterior.mean
-            S = variational_posterior.variance
-            mu2S = np.square(mu)+S
-            _, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
-
-            grad_gamma = np.einsum('n,q,nq->nq',dL_dpsi0,self.variances,mu2S) + np.einsum('nm,q,mq,nq->nq',dL_dpsi1,self.variances,Z,mu) +\
-                         np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dgamma)
-            grad_mu = np.einsum('n,nq,q,nq->nq',dL_dpsi0,gamma,2.*self.variances,mu) + np.einsum('nm,nq,q,mq->nq',dL_dpsi1,gamma,self.variances,Z) +\
-                      np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dmu)
-            grad_S = np.einsum('n,nq,q->nq',dL_dpsi0,gamma,self.variances) + np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dS)
-
-            return grad_mu, grad_S, grad_gamma
+            _, _, dL_dmu, dL_dS, dL_dgamma = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
+            return dL_dmu, dL_dS, dL_dgamma
         else:
             grad_mu, grad_S = np.zeros(variational_posterior.mean.shape), np.zeros(variational_posterior.mean.shape)
             # psi0
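
All four spike-and-slab branches now delegate to a single `PSICOMP_SSLinear` object instead of repeating the einsum bookkeeping inline. One behavioral detail visible in the diff: the new `psicomputations` returns psi2 already summed over data points (`->mo`, an MxM matrix), where the old inline `psi2` produced the NxMxM tensor (`->nmo`). A quick shape check (a sketch; assumes this branch of GPy is importable):

```python
import numpy as np
from GPy.kern._src.psi_comp.linear_psi_comp import PSICOMP_SSLinear

N, M, Q = 10, 4, 3
variance = np.ones(Q)            # ARD variances, one per latent dimension
Z = np.random.randn(M, Q)        # inducing inputs
mu = np.random.randn(N, Q)       # variational means
S = 0.1 * np.ones((N, Q))        # variational variances
gamma = 0.5 * np.ones((N, Q))    # spike-and-slab responsibilities

psi0, psi1, psi2 = PSICOMP_SSLinear().psicomputations(variance, Z, mu, S, gamma)
print(psi0.shape, psi1.shape, psi2.shape)   # (10,) (10, 4) (4, 4)
```
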
diff --git a/GPy/kern/_src/psi_comp/linear_psi_comp.py b/GPy/kern/_src/psi_comp/linear_psi_comp.py
index 22147366..03483b6b 100644
--- a/GPy/kern/_src/psi_comp/linear_psi_comp.py
+++ b/GPy/kern/_src/psi_comp/linear_psi_comp.py
@@ -8,44 +8,87 @@ The package for the Psi statistics computation of the linear kernel for SSGPLVM
 import numpy as np
 from GPy.util.caching import Cache_this
 
-#@Cache_this(limit=1)
-def _psi2computations(variance, Z, mu, S, gamma):
-    """
-    Z - MxQ
-    mu - NxQ
-    S - NxQ
-    gamma - NxQ
-    """
-    # here are the "statistics" for psi1 and psi2
-    # Produced intermediate results:
-    # _psi2                NxMxM
-    # _psi2_dvariance      NxMxMxQ
-    # _psi2_dZ             NxMxQ
-    # _psi2_dgamma         NxMxMxQ
-    # _psi2_dmu            NxMxMxQ
-    # _psi2_dS             NxMxMxQ
-
-    mu2 = np.square(mu)
-    gamma2 = np.square(gamma)
-    variance2 = np.square(variance)
-    mu2S = mu2+S # NxQ
-    common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
-
-    _dpsi2_dvariance = np.einsum('nq,q,mq,oq->nmoq',2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
-                       np.einsum('nq,mq,nq,no->nmoq',gamma,Z,mu,common_sum)+\
-                       np.einsum('nq,oq,nq,nm->nmoq',gamma,Z,mu,common_sum)
-
-    _dpsi2_dgamma = np.einsum('q,mq,oq,nq->nmoq',variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
-                    np.einsum('q,mq,nq,no->nmoq',variance,Z,mu,common_sum)+\
-                    np.einsum('q,oq,nq,nm->nmoq',variance,Z,mu,common_sum)
-
-    _dpsi2_dmu = np.einsum('q,mq,oq,nq,nq->nmoq',variance2,Z,Z,mu,2.*(gamma-gamma2))+\
-                 np.einsum('nq,q,mq,no->nmoq',gamma,variance,Z,common_sum)+\
-                 np.einsum('nq,q,oq,nm->nmoq',gamma,variance,Z,common_sum)
-
-    _dpsi2_dS = np.einsum('nq,q,mq,oq->nmoq',gamma,variance2,Z,Z)
-
-    _dpsi2_dZ = 2.*(np.einsum('nq,q,mq,nq->nmq',gamma,variance2,Z,mu2S)+np.einsum('nq,q,nq,nm->nmq',gamma,variance,mu,common_sum)
-                    -np.einsum('nq,q,mq,nq->nmq',gamma2,variance2,Z,mu2))
-
-    return _dpsi2_dvariance, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _dpsi2_dZ
\ No newline at end of file
+class PSICOMP_SSLinear(object):
+
+    #@Cache_this(limit=1, ignore_args=(0,))
+    def psicomputations(self, variance, Z, mu, S, gamma):
+        """
+        Compute the psi statistics for the spike-and-slab linear kernel.
+
+        Z - MxQ
+        mu - NxQ
+        S - NxQ
+        gamma - NxQ
+
+        Returned shapes:
+        psi0 - N
+        psi1 - NxM
+        psi2 - MxM (summed over the data points)
+        """
+        mu2 = np.square(mu)
+        variance2 = np.square(variance)
+        psi0 = np.einsum('q,nq,nq->n',variance,gamma,mu2+S)
+        psi1 = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu)
+        psi2 = np.einsum('nq,q,mq,oq,nq->mo',gamma,variance2,Z,Z,mu2+S)+\
+               np.einsum('nm,no->mo',psi1,psi1) - np.einsum('nq,q,mq,oq,nq->mo',np.square(gamma),variance2,Z,Z,mu2)
+        return psi0, psi1, psi2
+
+    #@Cache_this(limit=1, ignore_args=(0,1,2,3))
+    def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior):
+        mu = variational_posterior.mean
+        S = variational_posterior.variance
+        gamma = variational_posterior.binary_prob
+
+        # contributions coming through psi2
+        dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ = self._psi2computations(dL_dpsi2, variance, Z, mu, S, gamma)
+
+        # add the contributions coming through psi0 and psi1
+        mu2S = np.square(mu)+S
+        dL_dvar += np.einsum('n,nq,nq->q',dL_dpsi0,gamma,mu2S) + np.einsum('nm,nq,mq,nq->q',dL_dpsi1,gamma,Z,mu)
+        dL_dgamma += np.einsum('n,q,nq->nq',dL_dpsi0,variance,mu2S) + np.einsum('nm,q,mq,nq->nq',dL_dpsi1,variance,Z,mu)
+        dL_dmu += np.einsum('n,nq,q,nq->nq',dL_dpsi0,gamma,2.*variance,mu) + np.einsum('nm,nq,q,mq->nq',dL_dpsi1,gamma,variance,Z)
+        dL_dS += np.einsum('n,nq,q->nq',dL_dpsi0,gamma,variance)
+        dL_dZ += np.einsum('nm,nq,q,nq->mq',dL_dpsi1,gamma,variance,mu)
+
+        return dL_dvar, dL_dZ, dL_dmu, dL_dS, dL_dgamma
+
+    def _psi2computations(self, dL_dpsi2, variance, Z, mu, S, gamma):
+        """
+        Gradient contributions coming through psi2 only.
+
+        Returned shapes:
+        dL_dvar - Q
+        dL_dgamma - NxQ
+        dL_dmu - NxQ
+        dL_dS - NxQ
+        dL_dZ - MxQ
+        """
+        mu2 = np.square(mu)
+        gamma2 = np.square(gamma)
+        variance2 = np.square(variance)
+        mu2S = mu2+S # NxQ
+        common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
+
+        # the dL_dmu path still materializes the NxMxMxQ tensor before contracting
+        _dpsi2_dmu = np.einsum('q,mq,oq,nq,nq->nmoq',variance2,Z,Z,mu,2.*(gamma-gamma2))+\
+                     np.einsum('nq,q,mq,no->nmoq',gamma,variance,Z,common_sum)+\
+                     np.einsum('nq,q,oq,nm->nmoq',gamma,variance,Z,common_sum)
+        dL_dmu = np.einsum('mo,nmoq->nq', dL_dpsi2, _dpsi2_dmu)
+
+        # the remaining gradients contract dL_dpsi2 directly, avoiding the
+        # NxMxMxQ intermediates
+        dL_dvar = np.einsum('mo,nq,q,mq,oq->q',dL_dpsi2,2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
+                  np.einsum('mo,nq,mq,nq,no->q',dL_dpsi2,gamma,Z,mu,common_sum)+\
+                  np.einsum('mo,nq,oq,nq,nm->q',dL_dpsi2,gamma,Z,mu,common_sum)
+
+        dL_dgamma = np.einsum('mo,q,mq,oq,nq->nq',dL_dpsi2,variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
+                    np.einsum('mo,q,mq,nq,no->nq',dL_dpsi2,variance,Z,mu,common_sum)+\
+                    np.einsum('mo,q,oq,nq,nm->nq',dL_dpsi2,variance,Z,mu,common_sum)
+
+        dL_dS = np.einsum('mo,nq,q,mq,oq->nq',dL_dpsi2,gamma,variance2,Z,Z)
+
+        dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,mu2S)+\
+                    np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)-\
+                    np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma2,variance2,Z,mu2))
+
+        return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ
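
The derivative code now contracts `dL_dpsi2` into the psi2 factors in one fused `einsum` per gradient instead of materializing the old NxMxMxQ intermediate tensors (only the `dL_dmu` path still takes the tensor route). The two forms are algebraically identical; a standalone numpy check of the `dL_dS` contraction with hypothetical small shapes:

```python
import numpy as np

N, M, Q = 5, 3, 2
rng = np.random.RandomState(0)
dL_dpsi2 = rng.randn(M, M)
gamma = rng.rand(N, Q)
Z = rng.randn(M, Q)
variance2 = rng.rand(Q)

# old style: materialize the N x M x M x Q tensor, then contract it
big = np.einsum('nq,q,mq,oq->nmoq', gamma, variance2, Z, Z)
dL_dS_old = np.einsum('mo,nmoq->nq', dL_dpsi2, big)

# new style: one fused contraction, no large intermediate
dL_dS_new = np.einsum('mo,nq,q,mq,oq->nq', dL_dpsi2, gamma, variance2, Z, Z)
assert np.allclose(dL_dS_old, dL_dS_new)
```
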
diff --git a/GPy/kern/_src/rbf.py b/GPy/kern/_src/rbf.py
index 5944e765..f3590d40 100644
--- a/GPy/kern/_src/rbf.py
+++ b/GPy/kern/_src/rbf.py
@@ -9,7 +9,7 @@ from stationary import Stationary
 from GPy.util.caching import Cache_this
 from ...core.parameterization import variational
 from psi_comp import ssrbf_psi_comp
-from psi_comp.ssrbf_psi_gpucomp import PSICOMP_SSRBF
+from psi_comp import ssrbf_psi_gpucomp
 
 class RBF(Stationary):
     """
@@ -26,8 +26,11 @@ class RBF(Stationary):
         self.weave_options = {}
         self.group_spike_prob = False
 
+    def set_for_SpikeAndSlab(self):
         if self.useGPU:
-            self.psicomp = PSICOMP_SSRBF()
+            self.psicomp = ssrbf_psi_gpucomp.PSICOMP_SSRBF()
+        else:
+            self.psicomp = ssrbf_psi_comp
 
     def K_of_r(self, r):
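
Both `Linear` and `RBF` now expose `set_for_SpikeAndSlab()`, so the spike-and-slab psi-statistics backend (GPU or CPU for RBF) is attached only when a spike-and-slab model asks for it, and plain models never touch the GPU module. `SSGPLVM.__init__` calls it on the kernel (see ss_gplvm.py below); invoked by hand it would look like this (a sketch, assuming this branch of GPy is importable):

```python
import GPy

kern = GPy.kern.RBF(3, ARD=True)   # a plain RBF keeps its default psicomp
kern.set_for_SpikeAndSlab()        # attaches ssrbf_psi_gpucomp.PSICOMP_SSRBF()
                                   # if kern.useGPU is set, else the CPU
                                   # module ssrbf_psi_comp
```
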
diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py
index cc57b191..76e2f0ef 100644
--- a/GPy/models/ss_gplvm.py
+++ b/GPy/models/ss_gplvm.py
@@ -44,8 +44,10 @@ class SSGPLVM(SparseGP):
             X_variance = np.random.uniform(0,.1,X.shape)
 
         gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
-        #gamma[:] = 0.5 + 0.01 * np.random.randn(X.shape[0], input_dim)
-        gamma[:] = 0.5
+        gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
+        gamma[gamma>=1.-1e-9] = 1.-1e-9
+        gamma[gamma<1e-9] = 1e-9
+        #gamma[:] = 0.5
         if group_spike:
             gamma[:] = gamma.mean(axis=0)
 
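
The gamma initialization now jitters the responsibilities around 0.5 and clamps them into the open unit interval, so every entry stays strictly inside (0, 1). A standalone numpy illustration of the clamping:

```python
import numpy as np

gamma = 0.5 + 0.1 * np.random.randn(100, 5)
gamma[gamma >= 1. - 1e-9] = 1. - 1e-9   # clamp the upper tail below 1
gamma[gamma < 1e-9] = 1e-9              # clamp the lower tail above 0
assert (gamma > 0).all() and (gamma < 1).all()
```
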
@@ -57,19 +59,20 @@ class SSGPLVM(SparseGP):
         pi = np.empty((input_dim))
         pi[:] = 0.5
 
-        if mpi_comm != None:
-            mpi_comm.Bcast(X, root=0)
-            mpi_comm.Bcast(fracs, root=0)
-            mpi_comm.Bcast(X_variance, root=0)
-            mpi_comm.Bcast(gamma, root=0)
-            mpi_comm.Bcast(Z, root=0)
-            mpi_comm.Bcast(pi, root=0)
+#        if mpi_comm != None:
+#            mpi_comm.Bcast(X, root=0)
+#            mpi_comm.Bcast(fracs, root=0)
+#            mpi_comm.Bcast(X_variance, root=0)
+#            mpi_comm.Bcast(gamma, root=0)
+#            mpi_comm.Bcast(Z, root=0)
+#            mpi_comm.Bcast(pi, root=0)
 
         if likelihood is None:
             likelihood = Gaussian()
 
         if kernel is None:
             kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
+        kernel.set_for_SpikeAndSlab()
 
         self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
@@ -90,6 +93,8 @@ class SSGPLVM(SparseGP):
             self.X_local = self.X[Y_start:Y_end]
             self.Y_range = (Y_start, Y_end)
             self.Y_list = np.array(Y_list)
+            for p in self.flattened_parameters:
+                mpi_comm.Bcast(p, root=0)
 
     def set_X_gradients(self, X, X_grad):
         """Set the gradients of the posterior distribution of X in its specific form."""
@@ -125,3 +130,16 @@ class SSGPLVM(SparseGP):
 
         return dim_reduction_plots.plot_latent(self, plot_inducing=plot_inducing, *args, **kwargs)
 
+    def __getstate__(self):
+        dc = super(SSGPLVM, self).__getstate__()
+        del dc['mpi_comm']
+        del dc['Y_local']
+        del dc['X_local']
+        return dc
+
+    def __setstate__(self, state):
+        state['mpi_comm'] = None
+        Y_range = state['Y_range']
+        state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
+        state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
+        return super(SSGPLVM, self).__setstate__(state)
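
As with `VarDTC_minibatch`, `SSGPLVM.__getstate__` drops what cannot or need not be serialized (the MPI communicator and the per-rank views `Y_local`/`X_local`), and `__setstate__` re-derives the views from the saved `Y_range`. The rebuild step in isolation (standalone numpy illustration, not GPy code):

```python
import numpy as np

# a stand-in for the unpickled state dict, with the communicator already
# removed at save time and only the slice bounds preserved
state = {'Y': np.arange(20).reshape(10, 2),
         'X': np.random.randn(10, 3),
         'Y_range': (3, 7)}

state['mpi_comm'] = None                          # never serialized
Y_range = state['Y_range']
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]   # re-slice local views
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
assert state['Y_local'].shape[0] == Y_range[1] - Y_range[0]
```
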