diff --git a/GPy/inference/latent_function_inference/var_dtc.py b/GPy/inference/latent_function_inference/var_dtc.py index 64ee30c4..70637e3b 100644 --- a/GPy/inference/latent_function_inference/var_dtc.py +++ b/GPy/inference/latent_function_inference/var_dtc.py @@ -246,12 +246,10 @@ class VarDTCMissingData(LatentFunctionInference): uncertain_inputs = True psi0_all = kern.psi0(Z, X) psi1_all = kern.psi1(Z, X) - psi2_all = kern.psi2(Z, X) else: uncertain_inputs = False psi0_all = kern.Kdiag(X) psi1_all = kern.K(X, Z) - psi2_all = None Ys, traces = self._Y(Y) beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6) @@ -262,7 +260,7 @@ class VarDTCMissingData(LatentFunctionInference): dL_dpsi0_all = np.zeros(Y.shape[0]) dL_dpsi1_all = np.zeros((Y.shape[0], num_inducing)) if uncertain_inputs: - dL_dpsi2_all = np.zeros((Y.shape[0], num_inducing, num_inducing)) + dL_dpsi2_all = np.zeros((num_inducing, num_inducing)) dL_dR = 0 woodbury_vector = np.zeros((num_inducing, Y.shape[1])) @@ -278,6 +276,7 @@ class VarDTCMissingData(LatentFunctionInference): size = Y.shape[1] next_ten = 0 + for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)): if ((i+1.)/size) >= next_ten: logger.info('inference {:> 6.1%}'.format((i+1.)/size)) @@ -290,13 +289,13 @@ class VarDTCMissingData(LatentFunctionInference): psi0 = psi0_all[v] psi1 = psi1_all[v, :] - if uncertain_inputs: psi2 = psi2_all[v, :] + if uncertain_inputs: psi2 = kern.psi2(Z, X[v, :]) else: psi2 = None num_data = psi1.shape[0] if uncertain_inputs: if het_noise: psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0) - else: psi2_beta = psi2.sum(0) * beta + else: psi2_beta = psi2 * beta A = LmInv.dot(psi2_beta.dot(LmInv.T)) else: if het_noise: tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1))) @@ -331,7 +330,7 @@ class VarDTCMissingData(LatentFunctionInference): dL_dpsi0_all[v] += dL_dpsi0 dL_dpsi1_all[v, :] += dL_dpsi1 if uncertain_inputs: - dL_dpsi2_all[v, :] += dL_dpsi2 + dL_dpsi2_all += dL_dpsi2 # log marginal likelihood log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,