mirror of https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00

Commit a24a9b3edc: Add mean function functionality to dtc inference method
Parent: 754c67f71d
5 changed files with 45 additions and 18 deletions
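In short, this commit lets a sparse GP regression model carry a parametric mean function through the variational DTC inference path. A minimal usage sketch, mirroring the new VarDtcTest added at the bottom of this diff (the data and the Linear mapping are taken from that test; treat the rest as illustrative):

import numpy as np
import GPy

# Sparse GP regression with an explicit parametric mean function,
# which this commit threads through the VarDTC inference path.
np.random.seed(1)
X = np.linspace(0., 2*np.pi, 100)[:, None]
Y = -np.cos(X) + np.random.randn(*X.shape)*0.3 + 1.

mean_fn = GPy.mappings.Linear(input_dim=1, output_dim=1)
m = GPy.models.SparseGPRegression(X, Y, mean_function=mean_fn)
m.optimize()  # the mapping's weight gradients come from grad_dict['dL_dm']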
@@ -74,11 +74,16 @@ class SparseGP(GP):
         if trigger_update: self.update_model(True)

     def parameters_changed(self):
-        self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood, self.Y, self.Y_metadata)
+        self.posterior, self._log_marginal_likelihood, self.grad_dict = \
+            self.inference_method.inference(self.kern, self.X, self.Z, self.likelihood,
+                                            self.Y, Y_metadata=self.Y_metadata,
+                                            mean_function=self.mean_function)
         self._update_gradients()

     def _update_gradients(self):
         self.likelihood.update_gradients(self.grad_dict['dL_dthetaL'])
+        if self.mean_function is not None:
+            self.mean_function.update_gradients(self.grad_dict['dL_dm'], self.X)

         if isinstance(self.X, VariationalPosterior):
             #gradients wrt kernel
@@ -112,4 +117,3 @@ class SparseGP(GP):
         self.Z.gradient = self.kern.gradients_X(self.grad_dict['dL_dKmm'], self.Z)
         self.Z.gradient += self.kern.gradients_X(self.grad_dict['dL_dKnm'].T, self.Z, self.X)
         self._Zgrad = self.Z.gradient.copy()
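For reference, anything passed as mean_function only needs the small interface that parameters_changed() and _update_gradients() above rely on: f(X) to evaluate the mean, and update_gradients(dL_dm, X) to absorb the new gradient. A hedged sketch of a hypothetical constant-mean mapping (the Mapping base class and Param import path are assumptions about GPy's layout at the time, not part of this commit):

import numpy as np
from GPy.core.mapping import Mapping
from GPy.core.parameterization.param import Param

class ConstantMean(Mapping):
    """Hypothetical minimal mean function: m(X) = c, one output dimension."""
    def __init__(self, input_dim, name='constmean'):
        super(ConstantMean, self).__init__(input_dim, 1, name)
        self.c = Param('c', np.zeros(1))
        self.link_parameter(self.c)

    def f(self, X):
        # called by inference() to evaluate the mean at the inputs
        return np.ones((X.shape[0], 1)) * self.c

    def update_gradients(self, dL_dF, X):
        # receives grad_dict['dL_dm'] from SparseGP._update_gradients above
        self.c.gradient = dL_dF.sum(0)

    def gradients_X(self, dL_dF, X):
        # the mean is constant, so no gradient flows back to X
        return np.zeros(X.shape)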
@@ -34,7 +34,9 @@ class SparseGP_MPI(SparseGP):
     """

-    def __init__(self, X, Y, Z, kernel, likelihood, variational_prior=None, inference_method=None, name='sparse gp', Y_metadata=None, mpi_comm=None, normalizer=False):
+    def __init__(self, X, Y, Z, kernel, likelihood, variational_prior=None,
+                 mean_function=None, inference_method=None, name='sparse gp',
+                 Y_metadata=None, mpi_comm=None, normalizer=False):
         self._IN_OPTIMIZATION_ = False
         if mpi_comm != None:
             if inference_method is None:
@@ -42,12 +44,12 @@ class SparseGP_MPI(SparseGP):
             else:
                 assert isinstance(inference_method, VarDTC_minibatch), 'inference_method has to support MPI!'

-        super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, name=name, Y_metadata=Y_metadata, normalizer=normalizer)
+        super(SparseGP_MPI, self).__init__(X, Y, Z, kernel, likelihood, inference_method=inference_method, mean_function=mean_function, name=name, Y_metadata=Y_metadata, normalizer=normalizer)
         self.update_model(False)

         if variational_prior is not None:
             self.link_parameter(variational_prior)

         self.mpi_comm = mpi_comm
         # Manage the data (Y) division
         if mpi_comm != None:
@@ -118,4 +120,3 @@ class SparseGP_MPI(SparseGP):
             update_gradients(self, mpi_comm=self.mpi_comm)
         else:
             super(SparseGP_MPI,self).parameters_changed()
@@ -61,16 +61,20 @@ class VarDTC(LatentFunctionInference):
         return jitchol(tdot(Y))

     def get_VVTfactor(self, Y, prec):
-        return Y * prec # TODO chache this, and make it effective
+        return Y * prec # TODO cache this, and make it effective

     def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None, mean_function=None, precision=None, Lm=None, dL_dKmm=None, psi0=None, psi1=None, psi2=None, Z_tilde=None):
-        assert mean_function is None, "inference with a mean function not implemented"

         num_data, output_dim = Y.shape
         num_inducing = Z.shape[0]

         uncertain_inputs = isinstance(X, VariationalPosterior)

+        if mean_function is not None:
+            mean = mean_function.f(X)
+        else:
+            mean = 0
+
         if precision is None:
             #assume Gaussian likelihood
             precision = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), self.const_jitter)
@@ -78,10 +82,11 @@ class VarDTC(LatentFunctionInference):
         if precision.ndim == 1:
             precision = precision[:, None]
         het_noise = precision.size > 1

+        if (het_noise or uncertain_inputs) and mean_function is not None:
+            raise ValueError('Mean function not implemented with uncertain inputs or heteroscedasticity')
+
-        VVT_factor = precision*Y
-        trYYT = self.get_trYYT(Y)
+        VVT_factor = precision*(Y-mean)
+        #VVT_factor = precision*Y
+        trYYT = self.get_trYYT(Y-mean)

         # kernel computations, using BGPLVM notation
         if Lm is None:
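Net effect of the replaced lines: every data-dependent term of the bound now sees precision-scaled residuals instead of raw targets. As a sketch in the code's own notation:

\tilde{Y} = Y - m(X), \qquad
V \;(\texttt{VVT\_factor}) = \beta\,\tilde{Y}, \qquad
\operatorname{tr}(YY^{\top}) \;\to\; \operatorname{tr}(\tilde{Y}\tilde{Y}^{\top})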
@@ -128,14 +133,18 @@ class VarDTC(LatentFunctionInference):
         # factor B
         B = np.eye(num_inducing) + A
         LB = jitchol(B)
-        psi1Vf = np.dot(psi1.T, VVT_factor)
         # back substutue C into psi1Vf
-        tmp, _ = dtrtrs(Lm, psi1Vf, lower=1, trans=0)
-        _LBi_Lmi_psi1Vf, _ = dtrtrs(LB, tmp, lower=1, trans=0)
+        tmp, _ = dtrtrs(Lm, psi1.T, lower=1, trans=0)
+        _LBi_Lmi_psi1, _ = dtrtrs(LB, tmp, lower=1, trans=0)
+        _LBi_Lmi_psi1Vf = np.dot(_LBi_Lmi_psi1, VVT_factor)
         tmp, _ = dtrtrs(LB, _LBi_Lmi_psi1Vf, lower=1, trans=1)
         Cpsi1Vf, _ = dtrtrs(Lm, tmp, lower=1, trans=1)

         # data fit and derivative of L w.r.t. Kmm
+        dL_dm = -np.dot((_LBi_Lmi_psi1.T.dot(_LBi_Lmi_psi1))
+                        - np.eye(Y.shape[0]), VVT_factor)

         delit = tdot(_LBi_Lmi_psi1Vf)
         data_fit = np.trace(delit)
         DBi_plus_BiPBi = backsub_both_sides(LB, output_dim * np.eye(num_inducing) + delit)
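The new dL_dm lines transcribe to the following, writing Q for the back-substituted factor _LBi_Lmi_psi1 computed just above and V for VVT_factor; the identity is N x N, matching np.eye(Y.shape[0]):

\frac{\partial \mathcal{L}}{\partial m} = -\bigl(Q^{\top}Q - I_{N}\bigr)\,V,
\qquad Q = L_{B}^{-1} L_{m}^{-1} \Psi_{1}^{\top},
\qquad V = \beta\,(Y - m)

This quantity is handed back through grad_dict['dL_dm'] (next hunk) and consumed by mean_function.update_gradients in SparseGP._update_gradients above.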
@@ -181,7 +190,8 @@ class VarDTC(LatentFunctionInference):
         grad_dict = {'dL_dKmm': dL_dKmm,
                      'dL_dKdiag':dL_dpsi0,
                      'dL_dKnm':dL_dpsi1,
-                     'dL_dthetaL':dL_dthetaL}
+                     'dL_dthetaL':dL_dthetaL,
+                     'dL_dm':dL_dm}

         #get sufficient things for posterior prediction
         #TODO: do we really want to do this in the loop?
@@ -30,7 +30,7 @@ class SparseGPRegression(SparseGP_MPI):
     """

-    def __init__(self, X, Y, kernel=None, Z=None, num_inducing=10, X_variance=None, normalizer=None, mpi_comm=None, name='sparse_gp'):
+    def __init__(self, X, Y, kernel=None, Z=None, num_inducing=10, X_variance=None, mean_function=None, normalizer=None, mpi_comm=None, name='sparse_gp'):
         num_data, input_dim = X.shape

         # kern defaults to rbf (plus white for stability)
@@ -55,7 +55,8 @@ class SparseGPRegression(SparseGP_MPI):
         else:
             infr = VarDTC()

-        SparseGP_MPI.__init__(self, X, Y, Z, kernel, likelihood, inference_method=infr, normalizer=normalizer, mpi_comm=mpi_comm, name=name)
+        SparseGP_MPI.__init__(self, X, Y, Z, kernel, likelihood, mean_function=mean_function,
+                              inference_method=infr, normalizer=normalizer, mpi_comm=mpi_comm, name=name)

     def parameters_changed(self):
         from ..inference.latent_function_inference.var_dtc_parallel import update_gradients_sparsegp,VarDTC_minibatch
@@ -49,6 +49,7 @@ class InferenceXTestCase(unittest.TestCase):
         m.optimize()
         x, mi = m.infer_newX(m.Y, optimize=True)
         np.testing.assert_array_almost_equal(m.X, mi.X, decimal=2)

 class InferenceGPEP(unittest.TestCase):

     def genData(self):
@@ -132,6 +133,16 @@ class InferenceGPEP(unittest.TestCase):
                 np.sum(p._woodbury_vector - p0._woodbury_vector),
                 np.sum(p.woodbury_inv - p0.woodbury_inv)])) < 1e6)

+class VarDtcTest(unittest.TestCase):
+
+    def test_var_dtc_inference_with_mean(self):
+        """ Check dL_dm in var_dtc is calculated correctly"""
+        np.random.seed(1)
+        x = np.linspace(0.,2*np.pi,100)[:,None]
+        y = -np.cos(x)+np.random.randn(*x.shape)*0.3+1
+        m = GPy.models.SparseGPRegression(x,y, mean_function=GPy.mappings.Linear(input_dim=1, output_dim=1))
+        self.assertTrue(m.checkgrad())
+
+
 class HMCSamplerTest(unittest.TestCase):
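The test's correctness check is m.checkgrad(), which compares analytic gradients against finite differences for every parameter, including the Linear mapping's weights driven by dL_dm. A hedged snippet for inspecting the comparison interactively, reusing x and y from the test above (the verbose flag, which prints a per-parameter table, is an assumption about the checkgrad signature):

m = GPy.models.SparseGPRegression(x, y,
        mean_function=GPy.mappings.Linear(input_dim=1, output_dim=1))
print(m.checkgrad(verbose=True))  # the mapping's rows exercise the new dL_dm path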