diff --git a/GPy/core/parameterization/variational.py b/GPy/core/parameterization/variational.py index 01706922..5b3c4bca 100644 --- a/GPy/core/parameterization/variational.py +++ b/GPy/core/parameterization/variational.py @@ -126,6 +126,27 @@ class SpikeAndSlabPosterior(VariationalPosterior): super(SpikeAndSlabPosterior, self).__init__(means, variances, name) self.gamma = Param("binary_prob",binary_prob, Logistic(1e-10,1.-1e-10)) self.add_parameter(self.gamma) + + def __getitem__(self, s): + if isinstance(s, (int, slice, tuple, list, np.ndarray)): + import copy + n = self.__new__(self.__class__, self.name) + dc = self.__dict__.copy() + dc['mean'] = self.mean[s] + dc['variance'] = self.variance[s] + dc['binary_prob'] = self.binary_prob[s] + dc['_parameters_'] = copy.copy(self._parameters_) + n.__dict__.update(dc) + n._parameters_[dc['mean']._parent_index_] = dc['mean'] + n._parameters_[dc['variance']._parent_index_] = dc['variance'] + n._parameters_[dc['binary_prob']._parent_index_] = dc['binary_prob'] + n.ndim = n.mean.ndim + n.shape = n.mean.shape + n.num_data = n.mean.shape[0] + n.input_dim = n.mean.shape[1] if n.ndim != 1 else 1 + return n + else: + return super(VariationalPrior, self).__getitem__(s) def plot(self, *args): """ diff --git a/GPy/inference/latent_function_inference/var_dtc.py b/GPy/inference/latent_function_inference/var_dtc.py index 82f6c2b9..e2aa95f5 100644 --- a/GPy/inference/latent_function_inference/var_dtc.py +++ b/GPy/inference/latent_function_inference/var_dtc.py @@ -134,7 +134,7 @@ class VarDTC(object): # log marginal likelihood log_marginal = _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, - psi0, A, LB, trYYT, data_fit, Y) + psi0, A, LB, trYYT, data_fit, VVT_factor) #put the gradients in the right places dL_dR = _compute_dL_dR(likelihood, @@ -208,7 +208,7 @@ class VarDTCMissingData(object): self._subarray_indices = [[slice(None),slice(None)]] return [Y], [(Y**2).sum()] - def inference(self, kern, X, Z, likelihood, Y): + def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None): if isinstance(X, VariationalPosterior): uncertain_inputs = True psi0_all = kern.psi0(Z, X) @@ -305,7 +305,7 @@ class VarDTCMissingData(object): # log marginal likelihood log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, - psi0, A, LB, trYYT, data_fit) + psi0, A, LB, trYYT, data_fit,VVT_factor) #put the gradients in the right places dL_dR += _compute_dL_dR(likelihood, @@ -420,7 +420,7 @@ def _compute_dL_dR(likelihood, het_noise, uncertain_inputs, LB, _LBi_Lmi_psi1Vf, def _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, psi0, A, LB, trYYT, data_fit,Y): #compute log marginal likelihood if het_noise: - lik_1 = -0.5 * num_data * output_dim * np.log(2. * np.pi) + 0.5 * np.sum(np.log(beta)) - 0.5 * np.sum(beta * Y**2) + lik_1 = -0.5 * num_data * output_dim * np.log(2. * np.pi) + 0.5 * np.sum(np.log(beta)) - 0.5 * np.sum(beta * np.square(Y).sum(axis=-1)) lik_2 = -0.5 * output_dim * (np.sum(beta.flatten() * psi0) - np.trace(A)) else: lik_1 = -0.5 * num_data * output_dim * (np.log(2. * np.pi) - np.log(beta)) - 0.5 * beta * trYYT