an assortment of fixes

This commit is contained in:
James Hensman 2013-06-05 14:52:37 +01:00
parent 7040b26f41
commit 56a4bc4e21
5 changed files with 13 additions and 19 deletions

View file

@ -142,17 +142,17 @@ class SparseGP(GPBase):
def log_likelihood(self):
""" Compute the (lower bound on the) log marginal likelihood """
if self.likelihood.is_heteroscedastic:
A = -0.5 * self.N * self.input_dim * np.log(2.*np.pi) + 0.5 * np.sum(np.log(self.likelihood.precision)) - 0.5 * np.sum(self.likelihood.V * self.likelihood.Y)
B = -0.5 * self.input_dim * (np.sum(self.likelihood.precision.flatten() * self.psi0) - np.trace(self.A))
A = -0.5 * self.N * self.output_dim * np.log(2.*np.pi) + 0.5 * np.sum(np.log(self.likelihood.precision)) - 0.5 * np.sum(self.likelihood.V * self.likelihood.Y)
B = -0.5 * self.output_dim * (np.sum(self.likelihood.precision.flatten() * self.psi0) - np.trace(self.A))
else:
A = -0.5 * self.N * self.input_dim * (np.log(2.*np.pi) - np.log(self.likelihood.precision)) - 0.5 * self.likelihood.precision * self.likelihood.trYYT
B = -0.5 * self.input_dim * (np.sum(self.likelihood.precision * self.psi0) - np.trace(self.A))
C = -self.input_dim * (np.sum(np.log(np.diag(self.LB)))) # + 0.5 * self.num_inducing * np.log(sf2))
A = -0.5 * self.N * self.output_dim * (np.log(2.*np.pi) - np.log(self.likelihood.precision)) - 0.5 * self.likelihood.precision * self.likelihood.trYYT
B = -0.5 * self.output_dim * (np.sum(self.likelihood.precision * self.psi0) - np.trace(self.A))
C = -self.output_dim * (np.sum(np.log(np.diag(self.LB)))) # + 0.5 * self.num_inducing * np.log(sf2))
D = 0.5 * np.sum(np.square(self._LBi_Lmi_psi1V))
return A + B + C + D + self.likelihood.Z
def _set_params(self, p):
self.Z = p[:self.num_inducing * self.input_dim].reshape(self.num_inducing, self.input_dim)
self.Z = p[:self.num_inducing * self.output_dim].reshape(self.num_inducing, self.input_dim)
self.kern._set_params(p[self.Z.size:self.Z.size + self.kern.Nparam])
self.likelihood._set_params(p[self.Z.size + self.kern.Nparam:])
self._compute_kernel_matrices()