Correct dl_dm term in student t inference

This commit is contained in:
Tom Whitehead 2024-02-03 18:00:15 +00:00
parent 0c5de0708a
commit 1184afad99

View file

@ -35,15 +35,20 @@ class ExactStudentTInference(LatentFunctionInference):
# Log marginal
N = Y.shape[0]
D = Y.shape[1]
log_marginal = 0.5 * (-N * np.log((nu - 2) * np.pi) - W_logdet - (nu + N) * np.log(1 + beta / (nu - 2)))
log_marginal = 0.5 * (
-N * np.log((nu - 2) * np.pi)
- W_logdet
- (nu + N) * np.log(1 + beta / (nu - 2))
)
log_marginal += gammaln((nu + N) / 2) - gammaln(nu / 2)
# Gradients
dL_dK = 0.5 * ((nu + N) / (nu + beta - 2) * tdot(alpha) - D * Wi)
dL_dnu = -N / (nu - 2.) + digamma(0.5 * (nu + N)) - digamma(0.5 * nu)
dL_dnu -= np.log(1 + beta / (nu - 2.))
dL_dnu = -N / (nu - 2.0) + digamma(0.5 * (nu + N)) - digamma(0.5 * nu)
dL_dnu -= np.log(1 + beta / (nu - 2.0))
dL_dnu += ((nu + N) * beta) / ((nu - 2) * (beta + nu - 2))
dL_dnu *= 0.5
gradients = {'dL_dK': dL_dK, 'dL_dnu': dL_dnu, 'dL_dm': alpha}
dL_dm = (nu + N) / (nu + beta - 2) * alpha
gradients = {"dL_dK": dL_dK, "dL_dnu": dL_dnu, "dL_dm": dL_dm}
return posterior, log_marginal, gradients