mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-27 22:06:22 +02:00
basic vardtc working
This commit is contained in:
parent
ac2d28e2fd
commit
8eaa0bbf8a
1 changed files with 17 additions and 32 deletions
|
|
@ -36,11 +36,11 @@ class VarDTC(LatentFunctionInference):
|
|||
return param_to_array(np.sum(np.square(Y)))
|
||||
|
||||
def __getstate__(self):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
return self.limit
|
||||
|
||||
def __setstate__(self, state):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
self.limit = state
|
||||
from ...util.caching import Cacher
|
||||
self.get_trYYT = Cacher(self._get_trYYT, self.limit)
|
||||
|
|
@ -62,20 +62,9 @@ class VarDTC(LatentFunctionInference):
|
|||
return Y * prec # TODO chache this, and make it effective
|
||||
|
||||
def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
|
||||
if isinstance(X, VariationalPosterior):
|
||||
uncertain_inputs = True
|
||||
psi0 = kern.psi0(Z, X)
|
||||
psi1 = kern.psi1(Z, X)
|
||||
psi2 = kern.psi2(Z, X)
|
||||
else:
|
||||
uncertain_inputs = False
|
||||
psi0 = kern.Kdiag(X)
|
||||
psi1 = kern.K(X, Z)
|
||||
psi2 = None
|
||||
|
||||
#see whether we're using variational uncertain inputs
|
||||
|
||||
_, output_dim = Y.shape
|
||||
uncertain_inputs = isinstance(X, VariationalPosterior)
|
||||
|
||||
#see whether we've got a different noise variance for each datum
|
||||
beta = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
|
||||
|
|
@ -96,23 +85,21 @@ class VarDTC(LatentFunctionInference):
|
|||
diag.add(Kmm, self.const_jitter)
|
||||
Lm = jitchol(Kmm)
|
||||
|
||||
# The rather complex computations of A
|
||||
|
||||
# The rather complex computations of A, and the psi stats
|
||||
if uncertain_inputs:
|
||||
psi0 = kern.psi0(Z, X)
|
||||
psi1 = kern.psi1(Z, X)
|
||||
if het_noise:
|
||||
psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0)
|
||||
psi2_beta = np.sum([kern.psi2(Z,X[i:i+1,:]) * beta_i for i,beta_i in enumerate(beta)],0)
|
||||
else:
|
||||
psi2_beta = psi2.sum(0) * beta
|
||||
#if 0:
|
||||
# evals, evecs = linalg.eigh(psi2_beta)
|
||||
# clipped_evals = np.clip(evals, 0., 1e6) # TODO: make clipping configurable
|
||||
# if not np.array_equal(evals, clipped_evals):
|
||||
# pass # print evals
|
||||
# tmp = evecs * np.sqrt(clipped_evals)
|
||||
# tmp = tmp.T
|
||||
# no backsubstitution because of bound explosion on tr(A) if not...
|
||||
psi2_beta = kern.psi2(Z,X) * beta
|
||||
LmInv = dtrtri(Lm)
|
||||
A = LmInv.dot(psi2_beta.dot(LmInv.T))
|
||||
else:
|
||||
psi0 = kern.Kdiag(X)
|
||||
psi1 = kern.K(X, Z)
|
||||
psi2 = None
|
||||
if het_noise:
|
||||
tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1)))
|
||||
else:
|
||||
|
|
@ -149,7 +136,7 @@ class VarDTC(LatentFunctionInference):
|
|||
log_marginal = _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise,
|
||||
psi0, A, LB, trYYT, data_fit, VVT_factor)
|
||||
|
||||
#put the gradients in the right places
|
||||
#noise derivatives
|
||||
dL_dR = _compute_dL_dR(likelihood,
|
||||
het_noise, uncertain_inputs, LB,
|
||||
_LBi_Lmi_psi1Vf, DBi_plus_BiPBi, Lm, A,
|
||||
|
|
@ -158,6 +145,7 @@ class VarDTC(LatentFunctionInference):
|
|||
|
||||
dL_dthetaL = likelihood.exact_inference_gradients(dL_dR,Y_metadata)
|
||||
|
||||
#put the gradients in the right places
|
||||
if uncertain_inputs:
|
||||
grad_dict = {'dL_dKmm': dL_dKmm,
|
||||
'dL_dpsi0':dL_dpsi0,
|
||||
|
|
@ -203,11 +191,11 @@ class VarDTCMissingData(LatentFunctionInference):
|
|||
self._Y.limit = limit
|
||||
|
||||
def __getstate__(self):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
return self._Y.limit, self._inan
|
||||
|
||||
def __setstate__(self, state):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
from ...util.caching import Cacher
|
||||
self.limit = state[0]
|
||||
self._inan = state[1]
|
||||
|
|
@ -409,10 +397,7 @@ def _compute_dL_dpsi(num_inducing, num_data, output_dim, beta, Lm, VVT_factor, C
|
|||
dL_dpsi2 = None
|
||||
else:
|
||||
dL_dpsi2 = beta * dL_dpsi2_beta
|
||||
if uncertain_inputs:
|
||||
# repeat for each of the N psi_2 matrices
|
||||
dL_dpsi2 = np.repeat(dL_dpsi2[None, :, :], num_data, axis=0)
|
||||
else:
|
||||
if not uncertain_inputs:
|
||||
# subsume back into psi1 (==Kmn)
|
||||
dL_dpsi1 += 2.*np.dot(psi1, dL_dpsi2)
|
||||
dL_dpsi2 = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue