mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
[vardtc missing data] updated to new psi2 stuff
This commit is contained in:
parent
6a260409fa
commit
919be3ceba
1 changed files with 5 additions and 6 deletions
|
|
@ -246,12 +246,10 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
uncertain_inputs = True
|
uncertain_inputs = True
|
||||||
psi0_all = kern.psi0(Z, X)
|
psi0_all = kern.psi0(Z, X)
|
||||||
psi1_all = kern.psi1(Z, X)
|
psi1_all = kern.psi1(Z, X)
|
||||||
psi2_all = kern.psi2(Z, X)
|
|
||||||
else:
|
else:
|
||||||
uncertain_inputs = False
|
uncertain_inputs = False
|
||||||
psi0_all = kern.Kdiag(X)
|
psi0_all = kern.Kdiag(X)
|
||||||
psi1_all = kern.K(X, Z)
|
psi1_all = kern.K(X, Z)
|
||||||
psi2_all = None
|
|
||||||
|
|
||||||
Ys, traces = self._Y(Y)
|
Ys, traces = self._Y(Y)
|
||||||
beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
|
beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
|
||||||
|
|
@ -262,7 +260,7 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
dL_dpsi0_all = np.zeros(Y.shape[0])
|
dL_dpsi0_all = np.zeros(Y.shape[0])
|
||||||
dL_dpsi1_all = np.zeros((Y.shape[0], num_inducing))
|
dL_dpsi1_all = np.zeros((Y.shape[0], num_inducing))
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
dL_dpsi2_all = np.zeros((Y.shape[0], num_inducing, num_inducing))
|
dL_dpsi2_all = np.zeros((num_inducing, num_inducing))
|
||||||
|
|
||||||
dL_dR = 0
|
dL_dR = 0
|
||||||
woodbury_vector = np.zeros((num_inducing, Y.shape[1]))
|
woodbury_vector = np.zeros((num_inducing, Y.shape[1]))
|
||||||
|
|
@ -278,6 +276,7 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
|
|
||||||
size = Y.shape[1]
|
size = Y.shape[1]
|
||||||
next_ten = 0
|
next_ten = 0
|
||||||
|
|
||||||
for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)):
|
for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)):
|
||||||
if ((i+1.)/size) >= next_ten:
|
if ((i+1.)/size) >= next_ten:
|
||||||
logger.info('inference {:> 6.1%}'.format((i+1.)/size))
|
logger.info('inference {:> 6.1%}'.format((i+1.)/size))
|
||||||
|
|
@ -290,13 +289,13 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
|
|
||||||
psi0 = psi0_all[v]
|
psi0 = psi0_all[v]
|
||||||
psi1 = psi1_all[v, :]
|
psi1 = psi1_all[v, :]
|
||||||
if uncertain_inputs: psi2 = psi2_all[v, :]
|
if uncertain_inputs: psi2 = kern.psi2(Z, X[v, :])
|
||||||
else: psi2 = None
|
else: psi2 = None
|
||||||
num_data = psi1.shape[0]
|
num_data = psi1.shape[0]
|
||||||
|
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
if het_noise: psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0)
|
if het_noise: psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0)
|
||||||
else: psi2_beta = psi2.sum(0) * beta
|
else: psi2_beta = psi2 * beta
|
||||||
A = LmInv.dot(psi2_beta.dot(LmInv.T))
|
A = LmInv.dot(psi2_beta.dot(LmInv.T))
|
||||||
else:
|
else:
|
||||||
if het_noise: tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1)))
|
if het_noise: tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1)))
|
||||||
|
|
@ -331,7 +330,7 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
dL_dpsi0_all[v] += dL_dpsi0
|
dL_dpsi0_all[v] += dL_dpsi0
|
||||||
dL_dpsi1_all[v, :] += dL_dpsi1
|
dL_dpsi1_all[v, :] += dL_dpsi1
|
||||||
if uncertain_inputs:
|
if uncertain_inputs:
|
||||||
dL_dpsi2_all[v, :] += dL_dpsi2
|
dL_dpsi2_all += dL_dpsi2
|
||||||
|
|
||||||
# log marginal likelihood
|
# log marginal likelihood
|
||||||
log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,
|
log_marginal += _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue