further bug fix for sparsegp_mpi

This commit is contained in:
Zhenwen Dai 2014-08-27 09:45:06 +01:00
parent 1e1bbb2a26
commit 7ed0e70a46
3 changed files with 24 additions and 16 deletions

View file

@ -46,9 +46,15 @@ class SpikeAndSlabPrior(VariationalPrior):
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
if len(self.pi.shape)==2:
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
pi = self.pi[idx]
else:
pi = self.pi
var_mean = np.square(mu)/self.variance
var_S = (S/self.variance - np.log(S))
var_gamma = (gamma*np.log(gamma/self.pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-self.pi))).sum()
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2.
def update_gradients_KL(self, variational_posterior):

View file

@ -95,12 +95,12 @@ class SparseGP_MPI(SparseGP):
super(SparseGP_MPI, self).optimize(optimizer,start,**kwargs)
self.mpi_comm.Bcast(np.int32(-1),root=0)
elif self.mpi_comm.rank>0:
x = self._get_params_transformed().copy()
x = self.optimizer_array.copy()
flag = np.empty(1,dtype=np.int32)
while True:
self.mpi_comm.Bcast(flag,root=0)
if flag==1:
self._set_params_transformed(x)
self.optimizer_array = x
elif flag==-1:
break
else:
@ -109,5 +109,8 @@ class SparseGP_MPI(SparseGP):
self._IN_OPTIMIZATION_ = False
def parameters_changed(self):
update_gradients(self, mpi_comm=self.mpi_comm)
if isinstance(self.inference_method,VarDTC_minibatch):
update_gradients(self, mpi_comm=self.mpi_comm)
else:
super(SparseGP_MPI,self).parameters_changed()

View file

@ -84,18 +84,17 @@ class SSGPLVM(SparseGP_MPI):
"""Get the gradients of the posterior distribution of X in its specific form."""
return X.mean.gradient, X.variance.gradient, X.binary_prob.gradient
# def parameters_changed(self):
# if isinstance(self.inference_method, VarDTC_GPU) or isinstance(self.inference_method, VarDTC_minibatch):
# update_gradients(self, mpi_comm=self.mpi_comm)
# return
#
# super(SSGPLVM, self).parameters_changed()
# self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X)
#
# self.X.mean.gradient, self.X.variance.gradient, self.X.binary_prob.gradient = self.kern.gradients_qX_expectations(variational_posterior=self.X, Z=self.Z, dL_dpsi0=self.grad_dict['dL_dpsi0'], dL_dpsi1=self.grad_dict['dL_dpsi1'], dL_dpsi2=self.grad_dict['dL_dpsi2'])
#
# # update for the KL divergence
# self.variational_prior.update_gradients_KL(self.X)
def parameters_changed(self):
super(SSGPLVM,self).parameters_changed()
if isinstance(self.inference_method, VarDTC_minibatch):
return
self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X)
self.X.mean.gradient, self.X.variance.gradient, self.X.binary_prob.gradient = self.kern.gradients_qX_expectations(variational_posterior=self.X, Z=self.Z, dL_dpsi0=self.grad_dict['dL_dpsi0'], dL_dpsi1=self.grad_dict['dL_dpsi1'], dL_dpsi2=self.grad_dict['dL_dpsi2'])
# update for the KL divergence
self.variational_prior.update_gradients_KL(self.X)
def input_sensitivity(self):
if self.kern.ARD: