[vardtc missing data] updated to new psi2 stuff

This commit is contained in:
Max Zwiessele 2014-10-06 11:00:34 +01:00
parent 6a260409fa
commit 919be3ceba

View file

@ -246,12 +246,10 @@ class VarDTCMissingData(LatentFunctionInference):
uncertain_inputs = True uncertain_inputs = True
psi0_all = kern.psi0(Z, X) psi0_all = kern.psi0(Z, X)
psi1_all = kern.psi1(Z, X) psi1_all = kern.psi1(Z, X)
psi2_all = kern.psi2(Z, X)
else: else:
uncertain_inputs = False uncertain_inputs = False
psi0_all = kern.Kdiag(X) psi0_all = kern.Kdiag(X)
psi1_all = kern.K(X, Z) psi1_all = kern.K(X, Z)
psi2_all = None
Ys, traces = self._Y(Y) Ys, traces = self._Y(Y)
beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6) beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
@ -262,7 +260,7 @@ class VarDTCMissingData(LatentFunctionInference):
dL_dpsi0_all = np.zeros(Y.shape[0]) dL_dpsi0_all = np.zeros(Y.shape[0])
dL_dpsi1_all = np.zeros((Y.shape[0], num_inducing)) dL_dpsi1_all = np.zeros((Y.shape[0], num_inducing))
if uncertain_inputs: if uncertain_inputs:
dL_dpsi2_all = np.zeros((Y.shape[0], num_inducing, num_inducing)) dL_dpsi2_all = np.zeros((num_inducing, num_inducing))
dL_dR = 0 dL_dR = 0
woodbury_vector = np.zeros((num_inducing, Y.shape[1])) woodbury_vector = np.zeros((num_inducing, Y.shape[1]))
@ -278,6 +276,7 @@ class VarDTCMissingData(LatentFunctionInference):
size = Y.shape[1] size = Y.shape[1]
next_ten = 0 next_ten = 0
for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)): for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)):
if ((i+1.)/size) >= next_ten: if ((i+1.)/size) >= next_ten:
logger.info('inference {:> 6.1%}'.format((i+1.)/size)) logger.info('inference {:> 6.1%}'.format((i+1.)/size))
@ -290,13 +289,13 @@ class VarDTCMissingData(LatentFunctionInference):
psi0 = psi0_all[v] psi0 = psi0_all[v]
psi1 = psi1_all[v, :] psi1 = psi1_all[v, :]
if uncertain_inputs: psi2 = psi2_all[v, :] if uncertain_inputs: psi2 = kern.psi2(Z, X[v, :])
else: psi2 = None else: psi2 = None
num_data = psi1.shape[0] num_data = psi1.shape[0]
if uncertain_inputs: if uncertain_inputs:
if het_noise: psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0) if het_noise: psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0)
else: psi2_beta = psi2.sum(0) * beta else: psi2_beta = psi2 * beta
A = LmInv.dot(psi2_beta.dot(LmInv.T)) A = LmInv.dot(psi2_beta.dot(LmInv.T))
else: else:
if het_noise: tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1))) if het_noise: tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1)))
@ -331,7 +330,7 @@ class VarDTCMissingData(LatentFunctionInference):
dL_dpsi0_all[v] += dL_dpsi0 dL_dpsi0_all[v] += dL_dpsi0
dL_dpsi1_all[v, :] += dL_dpsi1 dL_dpsi1_all[v, :] += dL_dpsi1
if uncertain_inputs: if uncertain_inputs:
dL_dpsi2_all[v, :] += dL_dpsi2 dL_dpsi2_all += dL_dpsi2
# log marginal likelihood # log marginal likelihood
log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,