diff --git a/GPy/core/parameterization/variational.py b/GPy/core/parameterization/variational.py index 7cc5c99a..43e8d096 100644 --- a/GPy/core/parameterization/variational.py +++ b/GPy/core/parameterization/variational.py @@ -50,31 +50,29 @@ class SpikeAndSlabPrior(VariationalPrior): def KL_divergence(self, variational_posterior): mu = variational_posterior.mean S = variational_posterior.variance - gamma,gamma1 = variational_posterior.gamma_probabilities() - log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() + gamma = variational_posterior.gamma.values if len(self.pi.shape)==2: - idx = np.unique(gamma._raveled_index()/gamma.shape[-1]) + idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1]) pi = self.pi[idx] else: pi = self.pi var_mean = np.square(mu)/self.variance var_S = (S/self.variance - np.log(S)) - var_gamma = (gamma*(log_gamma-np.log(pi))).sum()+(gamma1*(log_gamma1-np.log(1-pi))).sum() + var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum() return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2. def update_gradients_KL(self, variational_posterior): mu = variational_posterior.mean S = variational_posterior.variance - gamma,gamma1 = variational_posterior.gamma_probabilities() - log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() + gamma = variational_posterior.gamma.values if len(self.pi.shape)==2: - idx = np.unique(gamma._raveled_index()/gamma.shape[-1]) + idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1]) pi = self.pi[idx] else: pi = self.pi - variational_posterior.binary_prob.gradient -= (np.log((1-pi)/pi)+log_gamma-log_gamma1+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.)*gamma*gamma1 + variational_posterior.binary_prob.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2. mu.gradient -= gamma*mu/self.variance S.gradient -= (1./self.variance - 1./S) * gamma /2. if self.learnPi: @@ -162,24 +160,8 @@ class SpikeAndSlabPosterior(VariationalPosterior): binary_prob : the probability of the distribution on the slab part. """ super(SpikeAndSlabPosterior, self).__init__(means, variances, name) - self.gamma = Param("binary_prob",binary_prob) + self.gamma = Param("binary_prob",binary_prob,Logistic(0.,1.)) self.link_parameter(self.gamma) - - @Cache_this(limit=5) - def gamma_probabilities(self): - prob = np.zeros_like(param_to_array(self.gamma)) - prob[self.gamma>-710] = 1./(1.+np.exp(-self.gamma[self.gamma>-710])) - prob1 = -np.zeros_like(param_to_array(self.gamma)) - prob1[self.gamma<710] = 1./(1.+np.exp(self.gamma[self.gamma<710])) - return prob, prob1 - - @Cache_this(limit=5) - def gamma_log_prob(self): - loggamma = param_to_array(self.gamma).copy() - loggamma[loggamma>-40] = -np.log1p(np.exp(-loggamma[loggamma>-40])) - loggamma1 = -param_to_array(self.gamma).copy() - loggamma1[loggamma1>-40] = -np.log1p(np.exp(-loggamma1[loggamma1>-40])) - return loggamma,loggamma1 def set_gradients(self, grad): self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index cac69872..2e633e16 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -169,11 +169,13 @@ class VarDTC_minibatch(LatentFunctionInference): Kmm = kern.K(Z).copy() diag.add(Kmm, self.const_jitter) - Lm = jitchol(Kmm, maxtries=100) + if not np.isfinite(Kmm).all(): + print Kmm + Lm = jitchol(Kmm) LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right') Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT - LL = jitchol(Lambda, maxtries=100) + LL = jitchol(Lambda) logdet_L = 2.*np.sum(np.log(np.diag(LL))) b = dtrtrs(LL,dtrtrs(Lm,psi1Y_full.T)[0])[0] bbt = np.square(b).sum() diff --git a/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py b/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py index 18a4d751..f6a24c86 100644 --- a/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py +++ b/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py @@ -22,12 +22,14 @@ try: # _psi1 NxM mu = variational_posterior.mean S = variational_posterior.variance + gamma = variational_posterior.binary_prob N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1] l2 = np.square(lengthscale) log_denom1 = np.log(S/l2+1) log_denom2 = np.log(2*S/l2+1) - log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() + log_gamma = np.log(gamma) + log_gamma1 = np.log(1.-gamma) variance = float(variance) psi0 = np.empty(N) psi0[:] = variance @@ -37,6 +39,7 @@ try: from ....util.misc import param_to_array S = param_to_array(S) mu = param_to_array(mu) + gamma = param_to_array(gamma) Z = param_to_array(Z) support_code = """ @@ -79,7 +82,7 @@ try: } } """ - weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz) + weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz) psi2 = psi2n.sum(axis=0) return psi0,psi1,psi2,psi2n @@ -94,12 +97,13 @@ try: mu = variational_posterior.mean S = variational_posterior.variance + gamma = variational_posterior.binary_prob N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1] l2 = np.square(lengthscale) log_denom1 = np.log(S/l2+1) log_denom2 = np.log(2*S/l2+1) - log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() - gamma, gamma1 = variational_posterior.gamma_probabilities() + log_gamma = np.log(gamma) + log_gamma1 = np.log(1.-gamma) variance = float(variance) dvar = np.zeros(1) @@ -113,6 +117,7 @@ try: from ....util.misc import param_to_array S = param_to_array(S) mu = param_to_array(mu) + gamma = param_to_array(gamma) Z = param_to_array(Z) support_code = """ @@ -130,7 +135,6 @@ try: double Zm1q = Z(m1,q); double Zm2q = Z(m2,q); double gnq = gamma(n,q); - double g1nq = gamma1(n,q); double mu_nq = mu(n,q); if(m2==0) { @@ -156,7 +160,7 @@ try: dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum); dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.; - dgamma(n,q) += lpsi1*(d_exp1*g1nq-d_exp2*gnq)/exp_sum; + dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum; dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum); dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum; } @@ -184,7 +188,7 @@ try: dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum; dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum; - dgamma(n,q) += lpsi2*(d_exp1*g1nq-d_exp2*gnq)/exp_sum; + dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum; dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum; dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum; } @@ -192,7 +196,7 @@ try: } } """ - weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','gamma1','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz) + weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz) dl *= 2.*lengthscale if not ARD: diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py index a61ad2a0..04006d84 100644 --- a/GPy/models/ss_gplvm.py +++ b/GPy/models/ss_gplvm.py @@ -39,7 +39,10 @@ class SSGPLVM(SparseGP_MPI): X_variance = np.random.uniform(0,.1,X.shape) if Gamma is None: - gamma = np.random.randn(X.shape[0], input_dim) + gamma = np.empty_like(X) # The posterior probabilities of the binary variable in the variational approximation + gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim) + gamma[gamma>1.-1e-9] = 1.-1e-9 + gamma[gamma<1e-9] = 1e-9 else: gamma = Gamma.copy()