[minibatch] added coverage for branching, spottet bug in X_variance

This commit is contained in:
mzwiessele 2016-03-08 11:19:50 +00:00
parent 88a9b92c80
commit 67043e83d7
2 changed files with 13 additions and 12 deletions

View file

@ -40,12 +40,13 @@ class BayesianGPLVMMiniBatch(SparseGPMiniBatch):
Z = np.random.permutation(X.copy())[:num_inducing]
assert Z.shape[1] == X.shape[1]
if X_variance == False:
if X_variance is False:
self.logger.info('no variance on X, activating sparse GPLVM')
X = Param("latent space", X)
elif X_variance is None:
self.logger.info("initializing latent space variance ~ uniform(0,.1)")
X_variance = np.random.uniform(0,.1,X.shape)
else:
if X_variance is None:
self.logger.info("initializing latent space variance ~ uniform(0,.1)")
X_variance = np.random.uniform(0,.1,X.shape)
self.variational_prior = NormalPrior()
X = NormalPosterior(X, X_variance)
@ -71,13 +72,13 @@ class BayesianGPLVMMiniBatch(SparseGPMiniBatch):
self.X = X
self.link_parameter(self.X, 0)
def set_X_gradients(self, X, X_grad):
"""Set the gradients of the posterior distribution of X in its specific form."""
X.mean.gradient, X.variance.gradient = X_grad
#def set_X_gradients(self, X, X_grad):
# """Set the gradients of the posterior distribution of X in its specific form."""
# X.mean.gradient, X.variance.gradient = X_grad
def get_X_gradients(self, X):
"""Get the gradients of the posterior distribution of X in its specific form."""
return X.mean.gradient, X.variance.gradient
#def get_X_gradients(self, X):
# """Get the gradients of the posterior distribution of X in its specific form."""
# return X.mean.gradient, X.variance.gradient
def _outer_values_update(self, full_values):
"""
@ -122,7 +123,7 @@ class BayesianGPLVMMiniBatch(SparseGPMiniBatch):
if self.missing_data or not self.stochastics:
self._log_marginal_likelihood -= kl_fctr*self.variational_prior.KL_divergence(self.X)
elif self.stochastics:
else: #self.stochastics is given:
d = self.output_dim
self._log_marginal_likelihood -= kl_fctr*self.variational_prior.KL_divergence(self.X)*self.stochastics.batchsize/d