bug fix w.r.t. var_dtc.py

This commit is contained in:
Zhenwen Dai 2014-03-18 12:35:28 +00:00
parent 9e64f116d8
commit 5acb66cf78
2 changed files with 25 additions and 4 deletions

View file

@ -126,6 +126,27 @@ class SpikeAndSlabPosterior(VariationalPosterior):
super(SpikeAndSlabPosterior, self).__init__(means, variances, name) super(SpikeAndSlabPosterior, self).__init__(means, variances, name)
self.gamma = Param("binary_prob",binary_prob, Logistic(1e-10,1.-1e-10)) self.gamma = Param("binary_prob",binary_prob, Logistic(1e-10,1.-1e-10))
self.add_parameter(self.gamma) self.add_parameter(self.gamma)
def __getitem__(self, s):
if isinstance(s, (int, slice, tuple, list, np.ndarray)):
import copy
n = self.__new__(self.__class__, self.name)
dc = self.__dict__.copy()
dc['mean'] = self.mean[s]
dc['variance'] = self.variance[s]
dc['binary_prob'] = self.binary_prob[s]
dc['_parameters_'] = copy.copy(self._parameters_)
n.__dict__.update(dc)
n._parameters_[dc['mean']._parent_index_] = dc['mean']
n._parameters_[dc['variance']._parent_index_] = dc['variance']
n._parameters_[dc['binary_prob']._parent_index_] = dc['binary_prob']
n.ndim = n.mean.ndim
n.shape = n.mean.shape
n.num_data = n.mean.shape[0]
n.input_dim = n.mean.shape[1] if n.ndim != 1 else 1
return n
else:
return super(VariationalPrior, self).__getitem__(s)
def plot(self, *args): def plot(self, *args):
""" """

View file

@ -134,7 +134,7 @@ class VarDTC(object):
# log marginal likelihood # log marginal likelihood
log_marginal = _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, log_marginal = _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,
psi0, A, LB, trYYT, data_fit, Y) psi0, A, LB, trYYT, data_fit, VVT_factor)
#put the gradients in the right places #put the gradients in the right places
dL_dR = _compute_dL_dR(likelihood, dL_dR = _compute_dL_dR(likelihood,
@ -208,7 +208,7 @@ class VarDTCMissingData(object):
self._subarray_indices = [[slice(None),slice(None)]] self._subarray_indices = [[slice(None),slice(None)]]
return [Y], [(Y**2).sum()] return [Y], [(Y**2).sum()]
def inference(self, kern, X, Z, likelihood, Y): def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
if isinstance(X, VariationalPosterior): if isinstance(X, VariationalPosterior):
uncertain_inputs = True uncertain_inputs = True
psi0_all = kern.psi0(Z, X) psi0_all = kern.psi0(Z, X)
@ -305,7 +305,7 @@ class VarDTCMissingData(object):
# log marginal likelihood # log marginal likelihood
log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,
psi0, A, LB, trYYT, data_fit) psi0, A, LB, trYYT, data_fit,VVT_factor)
#put the gradients in the right places #put the gradients in the right places
dL_dR += _compute_dL_dR(likelihood, dL_dR += _compute_dL_dR(likelihood,
@ -420,7 +420,7 @@ def _compute_dL_dR(likelihood, het_noise, uncertain_inputs, LB, _LBi_Lmi_psi1Vf,
def _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, psi0, A, LB, trYYT, data_fit,Y): def _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, psi0, A, LB, trYYT, data_fit,Y):
#compute log marginal likelihood #compute log marginal likelihood
if het_noise: if het_noise:
lik_1 = -0.5 * num_data * output_dim * np.log(2. * np.pi) + 0.5 * np.sum(np.log(beta)) - 0.5 * np.sum(beta * Y**2) lik_1 = -0.5 * num_data * output_dim * np.log(2. * np.pi) + 0.5 * np.sum(np.log(beta)) - 0.5 * np.sum(beta * np.square(Y).sum(axis=-1))
lik_2 = -0.5 * output_dim * (np.sum(beta.flatten() * psi0) - np.trace(A)) lik_2 = -0.5 * output_dim * (np.sum(beta.flatten() * psi0) - np.trace(A))
else: else:
lik_1 = -0.5 * num_data * output_dim * (np.log(2. * np.pi) - np.log(beta)) - 0.5 * beta * trYYT lik_1 = -0.5 * num_data * output_dim * (np.log(2. * np.pi) - np.log(beta)) - 0.5 * beta * trYYT