basic vardtc working

This commit is contained in:
James Hensman 2014-05-30 16:45:51 +01:00
parent ac2d28e2fd
commit 8eaa0bbf8a

View file

@@ -36,11 +36,11 @@ class VarDTC(LatentFunctionInference):
return param_to_array(np.sum(np.square(Y)))
def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled.
# has to be overridden, as Cacher objects cannot be pickled.
return self.limit
def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled.
# has to be overridden, as Cacher objects cannot be pickled.
self.limit = state
from ...util.caching import Cacher
self.get_trYYT = Cacher(self._get_trYYT, self.limit)
@@ -62,20 +62,9 @@ class VarDTC(LatentFunctionInference):
return Y * prec # TODO chache this, and make it effective
def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
if isinstance(X, VariationalPosterior):
uncertain_inputs = True
psi0 = kern.psi0(Z, X)
psi1 = kern.psi1(Z, X)
psi2 = kern.psi2(Z, X)
else:
uncertain_inputs = False
psi0 = kern.Kdiag(X)
psi1 = kern.K(X, Z)
psi2 = None
#see whether we're using variational uncertain inputs
_, output_dim = Y.shape
uncertain_inputs = isinstance(X, VariationalPosterior)
#see whether we've got a different noise variance for each datum
beta = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
@@ -96,23 +85,21 @@ class VarDTC(LatentFunctionInference):
diag.add(Kmm, self.const_jitter)
Lm = jitchol(Kmm)
# The rather complex computations of A
# The rather complex computations of A, and the psi stats
if uncertain_inputs:
psi0 = kern.psi0(Z, X)
psi1 = kern.psi1(Z, X)
if het_noise:
psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0)
psi2_beta = np.sum([kern.psi2(Z,X[i:i+1,:]) * beta_i for i,beta_i in enumerate(beta)],0)
else:
psi2_beta = psi2.sum(0) * beta
#if 0:
# evals, evecs = linalg.eigh(psi2_beta)
# clipped_evals = np.clip(evals, 0., 1e6) # TODO: make clipping configurable
# if not np.array_equal(evals, clipped_evals):
# pass # print evals
# tmp = evecs * np.sqrt(clipped_evals)
# tmp = tmp.T
# no backsubstitution because of bound explosion on tr(A) if not...
psi2_beta = kern.psi2(Z,X) * beta
LmInv = dtrtri(Lm)
A = LmInv.dot(psi2_beta.dot(LmInv.T))
else:
psi0 = kern.Kdiag(X)
psi1 = kern.K(X, Z)
psi2 = None
if het_noise:
tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1)))
else:
@@ -149,7 +136,7 @@ class VarDTC(LatentFunctionInference):
log_marginal = _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,
psi0, A, LB, trYYT, data_fit, VVT_factor)
#put the gradients in the right places
#noise derivatives
dL_dR = _compute_dL_dR(likelihood,
het_noise, uncertain_inputs, LB,
_LBi_Lmi_psi1Vf, DBi_plus_BiPBi, Lm, A,
@@ -158,6 +145,7 @@ class VarDTC(LatentFunctionInference):
dL_dthetaL = likelihood.exact_inference_gradients(dL_dR,Y_metadata)
#put the gradients in the right places
if uncertain_inputs:
grad_dict = {'dL_dKmm': dL_dKmm,
'dL_dpsi0':dL_dpsi0,
@@ -203,11 +191,11 @@ class VarDTCMissingData(LatentFunctionInference):
self._Y.limit = limit
def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled.
# has to be overridden, as Cacher objects cannot be pickled.
return self._Y.limit, self._inan
def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled.
# has to be overridden, as Cacher objects cannot be pickled.
from ...util.caching import Cacher
self.limit = state[0]
self._inan = state[1]
@@ -409,10 +397,7 @@ def _compute_dL_dpsi(num_inducing, num_data, output_dim, beta, Lm, VVT_factor, C
dL_dpsi2 = None
else:
dL_dpsi2 = beta * dL_dpsi2_beta
if uncertain_inputs:
# repeat for each of the N psi_2 matrices
dL_dpsi2 = np.repeat(dL_dpsi2[None, :, :], num_data, axis=0)
else:
if not uncertain_inputs:
# subsume back into psi1 (==Kmn)
dL_dpsi1 += 2.*np.dot(psi1, dL_dpsi2)
dL_dpsi2 = None