mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 21:42:39 +02:00
further bug fix for sparsegp_mpi
This commit is contained in:
parent
1e1bbb2a26
commit
7ed0e70a46
3 changed files with 24 additions and 16 deletions
|
|
@ -46,9 +46,15 @@ class SpikeAndSlabPrior(VariationalPrior):
|
|||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
gamma = variational_posterior.binary_prob
|
||||
if len(self.pi.shape)==2:
|
||||
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
|
||||
pi = self.pi[idx]
|
||||
else:
|
||||
pi = self.pi
|
||||
|
||||
var_mean = np.square(mu)/self.variance
|
||||
var_S = (S/self.variance - np.log(S))
|
||||
var_gamma = (gamma*np.log(gamma/self.pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-self.pi))).sum()
|
||||
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
|
||||
return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2.
|
||||
|
||||
def update_gradients_KL(self, variational_posterior):
|
||||
|
|
|
|||
|
|
@ -95,12 +95,12 @@ class SparseGP_MPI(SparseGP):
|
|||
super(SparseGP_MPI, self).optimize(optimizer,start,**kwargs)
|
||||
self.mpi_comm.Bcast(np.int32(-1),root=0)
|
||||
elif self.mpi_comm.rank>0:
|
||||
x = self._get_params_transformed().copy()
|
||||
x = self.optimizer_array.copy()
|
||||
flag = np.empty(1,dtype=np.int32)
|
||||
while True:
|
||||
self.mpi_comm.Bcast(flag,root=0)
|
||||
if flag==1:
|
||||
self._set_params_transformed(x)
|
||||
self.optimizer_array = x
|
||||
elif flag==-1:
|
||||
break
|
||||
else:
|
||||
|
|
@ -109,5 +109,8 @@ class SparseGP_MPI(SparseGP):
|
|||
self._IN_OPTIMIZATION_ = False
|
||||
|
||||
def parameters_changed(self):
|
||||
update_gradients(self, mpi_comm=self.mpi_comm)
|
||||
if isinstance(self.inference_method,VarDTC_minibatch):
|
||||
update_gradients(self, mpi_comm=self.mpi_comm)
|
||||
else:
|
||||
super(SparseGP_MPI,self).parameters_changed()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue