switch psi2 statistics design

This commit is contained in:
Zhenwen Dai 2014-05-06 16:30:17 +01:00
parent 7d41001ae1
commit 06336bf0d4
12 changed files with 245 additions and 124 deletions

View file

@ -11,7 +11,7 @@ from ..likelihoods import Gaussian
from ..inference.optimization import SCG
from ..util import linalg
from ..core.parameterization.variational import SpikeAndSlabPrior, SpikeAndSlabPosterior
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
@ -41,7 +41,7 @@ class SSGPLVM(SparseGP):
if X_variance is None: # The variance of the variational approximation (S)
X_variance = np.random.uniform(0,.1,X.shape)
gamma = np.empty_like(X) # The posterior probabilities of the binary variable in the variational approximation
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
gamma[:] = 0.5 + 0.01 * np.random.randn(X.shape[0], input_dim)
if group_spike:
@ -60,12 +60,15 @@ class SSGPLVM(SparseGP):
pi = np.empty((input_dim))
pi[:] = 0.5
self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
X = np.asfortranarray(X)
X_variance = np.asfortranarray(X_variance)
gamma = np.asfortranarray(gamma)
X = SpikeAndSlabPosterior(X, X_variance, gamma)
if group_spike:
kernel.group_spike_prob = True
self.variational_prior.group_spike_prob = True
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method, name, **kwargs)
self.add_parameter(self.X, index=0)
@ -76,7 +79,7 @@ class SSGPLVM(SparseGP):
X.mean.gradient, X.variance.gradient, X.binary_prob.gradient = X_grad
def parameters_changed(self):
if isinstance(self.inference_method, VarDTC_GPU):
if isinstance(self.inference_method, VarDTC_GPU) or isinstance(self.inference_method, VarDTC_minibatch):
update_gradients(self)
return