From 1840b7e6b881a47a137b94c53cbfcfa41920b9fc Mon Sep 17 00:00:00 2001
From: Zhenwen Dai
Date: Mon, 3 Nov 2014 16:04:15 +0000
Subject: [PATCH] extend inference X for all GP models

---
 GPy/core/gp.py                                 | 16 +++++
 .../latent_function_inference/inferenceX.py    | 69 +++++++++++++------
 GPy/models/bayesian_gplvm.py                   | 16 -----
 GPy/models/sparse_gplvm.py                     |  3 +-
 GPy/testing/inference_tests.py                 | 11 +++
 5 files changed, 78 insertions(+), 37 deletions(-)

diff --git a/GPy/core/gp.py b/GPy/core/gp.py
index 0e93cd99..e0dfde0c 100644
--- a/GPy/core/gp.py
+++ b/GPy/core/gp.py
@@ -354,3 +354,19 @@ class GP(Model):
             print "KeyboardInterrupt caught, calling on_optimization_end() to round things up"
             self.inference_method.on_optimization_end()
             raise
+
+    def infer_newX(self, Y_new, optimize=True):
+        """
+        Infer the distribution of X for the new observed data *Y_new*.
+
+        :param model: the GPy model used in inference
+        :type model: GPy.core.Model
+        :param Y_new: the new observed data for inference
+        :type Y_new: numpy.ndarray
+        :param optimize: whether to optimize the location of the new X (True by default)
+        :type optimize: boolean
+        :return: a tuple containing the estimated posterior distribution of X and the model that optimizes X
+        :rtype: (GPy.core.parameterization.variational.VariationalPosterior, GPy.core.Model)
+        """
+        from ..inference.latent_function_inference.inferenceX import infer_newX
+        return infer_newX(self, Y_new, optimize=optimize)
diff --git a/GPy/inference/latent_function_inference/inferenceX.py b/GPy/inference/latent_function_inference/inferenceX.py
index 7480c910..66fbcd4d 100644
--- a/GPy/inference/latent_function_inference/inferenceX.py
+++ b/GPy/inference/latent_function_inference/inferenceX.py
@@ -51,8 +51,18 @@ class InferenceX(Model):
             self.kern.GPU(True)
         from copy import deepcopy
         self.posterior = deepcopy(model.posterior)
-        self.variational_prior = model.variational_prior.copy()
-        self.Z = model.Z.copy()
+        if hasattr(model, 'variational_prior'):
+            self.uncertain_input = True
+            self.variational_prior = model.variational_prior.copy()
+        else:
+            self.uncertain_input = False
+        if hasattr(model, 'inducing_inputs'):
+            self.sparse_gp = True
+            self.Z = model.Z.copy()
+        else:
+            self.sparse_gp = False
+            self.uncertain_input = False
+            self.Z = model.X.copy()
         self.Y = Y
         self.X = self._init_X(model, Y, init=init)
         self.compute_dL()
@@ -72,6 +82,8 @@
             dist = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
         elif init=='NCC':
             dist = Y_new.dot(Y.T)
+        elif init=='rand':
+            dist = np.random.rand(Y_new.shape[0],Y.shape[0])
         idx = dist.argmin(axis=1)
 
         from ...models import SSGPLVM
@@ -81,7 +93,11 @@
             if model.group_spike:
                 X.gamma.fix()
         else:
-            X = variational.NormalPosterior(param_to_array(model.X.mean[idx]), param_to_array(model.X.variance[idx]))
+            if self.uncertain_input and self.sparse_gp:
+                X = variational.NormalPosterior(param_to_array(model.X.mean[idx]), param_to_array(model.X.variance[idx]))
+            else:
+                from ...core import Param
+                X = Param('latent mean',param_to_array(model.X[idx]).copy())
 
         return X
 
@@ -99,29 +115,42 @@
         else:
             self.dL_dpsi2 = beta*(output_dim*self.posterior.woodbury_inv - np.einsum('md,od->mo',wv, wv))/2.
         self.dL_dpsi1 = beta*np.dot(self.Y, wv.T)
-        self.dL_dpsi0 = -beta/2.* np.ones(self.Y.shape[0]) #self.dL_dpsi0[:] = 0
+        self.dL_dpsi0 = -beta/2.* np.ones(self.Y.shape[0])
 
     def parameters_changed(self):
-        psi0 = self.kern.psi0(self.Z, self.X)
-        psi1 = self.kern.psi1(self.Z, self.X)
-        psi2 = self.kern.psi2(self.Z, self.X)
+        if self.uncertain_input:
+            psi0 = self.kern.psi0(self.Z, self.X)
+            psi1 = self.kern.psi1(self.Z, self.X)
+            psi2 = self.kern.psi2(self.Z, self.X)
+        else:
+            psi0 = self.kern.Kdiag(self.X)
+            psi1 = self.kern.K(self.X, self.Z)
+            psi2 = np.dot(psi1.T,psi1)
 
         self._log_marginal_likelihood = (self.dL_dpsi2*psi2).sum()+(self.dL_dpsi1*psi1).sum()+(self.dL_dpsi0*psi0).sum()
 
-        X_grad = self.kern.gradients_qX_expectations(variational_posterior=self.X, Z=self.Z, dL_dpsi0=self.dL_dpsi0, dL_dpsi1=self.dL_dpsi1, dL_dpsi2=self.dL_dpsi2)
-        self.X.set_gradients(X_grad)
-        from ...core.parameterization.variational import SpikeAndSlabPrior
-        if isinstance(self.variational_prior, SpikeAndSlabPrior):
-            # Update Log-likelihood
-            KL_div = self.variational_prior.KL_divergence(self.X, N=self.Y.shape[0])
-            # update for the KL divergence
-            self.variational_prior.update_gradients_KL(self.X, N=self.Y.shape[0])
+        if self.uncertain_input:
+            X_grad = self.kern.gradients_qX_expectations(variational_posterior=self.X, Z=self.Z, dL_dpsi0=self.dL_dpsi0, dL_dpsi1=self.dL_dpsi1, dL_dpsi2=self.dL_dpsi2)
+            self.X.set_gradients(X_grad)
         else:
-            # Update Log-likelihood
-            KL_div = self.variational_prior.KL_divergence(self.X)
-            # update for the KL divergence
-            self.variational_prior.update_gradients_KL(self.X)
-        self._log_marginal_likelihood += -KL_div
+            dL_dpsi1 = self.dL_dpsi1 + 2.*np.dot(psi1,self.dL_dpsi2)
+            X_grad = self.kern.gradients_X_diag(self.dL_dpsi0, self.X)
+            X_grad += self.kern.gradients_X(dL_dpsi1, self.X, self.Z)
+            self.X.gradient = X_grad
+
+        if self.uncertain_input:
+            from ...core.parameterization.variational import SpikeAndSlabPrior
+            if isinstance(self.variational_prior, SpikeAndSlabPrior):
+                # Update Log-likelihood
+                KL_div = self.variational_prior.KL_divergence(self.X, N=self.Y.shape[0])
+                # update for the KL divergence
+                self.variational_prior.update_gradients_KL(self.X, N=self.Y.shape[0])
+            else:
+                # Update Log-likelihood
+                KL_div = self.variational_prior.KL_divergence(self.X)
+                # update for the KL divergence
+                self.variational_prior.update_gradients_KL(self.X)
+            self._log_marginal_likelihood += -KL_div
 
     def log_likelihood(self):
         return self._log_marginal_likelihood
diff --git a/GPy/models/bayesian_gplvm.py b/GPy/models/bayesian_gplvm.py
index 629522d1..73fd6f8f 100644
--- a/GPy/models/bayesian_gplvm.py
+++ b/GPy/models/bayesian_gplvm.py
@@ -141,22 +141,6 @@ class BayesianGPLVM(SparseGP_MPI):
                                  resolution, ax, marker, s,
                                  fignum, plot_inducing, legend,
                                  plot_limits, aspect, updates, predict_kwargs, imshow_kwargs)
-
-    def infer_newX(self, Y_new, optimize=True, ):
-        """
-        Infer the distribution of X for the new observed data *Y_new*.
-
-        :param model: the GPy model used in inference
-        :type model: GPy.core.Model
-        :param Y_new: the new observed data for inference
-        :type Y_new: numpy.ndarray
-        :param optimize: whether to optimize the location of new X (True by default)
-        :type optimize: boolean
-        :return: a tuple containing the estimated posterior distribution of X and the model that optimize X
-        :rtype: (GPy.core.parameterization.variational.VariationalPosterior, GPy.core.Model)
-        """
-        from ..inference.latent_function_inference.inferenceX import infer_newX
-        return infer_newX(self, Y_new, optimize=optimize)
 
     def do_test_latents(self, Y):
         """
diff --git a/GPy/models/sparse_gplvm.py b/GPy/models/sparse_gplvm.py
index 251103f4..d1ad5884 100644
--- a/GPy/models/sparse_gplvm.py
+++ b/GPy/models/sparse_gplvm.py
@@ -26,7 +26,8 @@ class SparseGPLVM(SparseGPRegression):
 
     def parameters_changed(self):
         super(SparseGPLVM, self).parameters_changed()
-        self.X.gradient = self.kern.gradients_X(self.grad_dict['dL_dKnm'], self.X, self.Z)
+        self.X.gradient = self.kern.gradients_X_diag(self.grad_dict['dL_dKdiag'], self.X)
+        self.X.gradient += self.kern.gradients_X(self.grad_dict['dL_dKnm'], self.X, self.Z)
 
     def plot_latent(self, labels=None, which_indices=None,
             resolution=50, ax=None, marker='o', s=40,
diff --git a/GPy/testing/inference_tests.py b/GPy/testing/inference_tests.py
index c7efa821..fd81022a 100644
--- a/GPy/testing/inference_tests.py
+++ b/GPy/testing/inference_tests.py
@@ -65,6 +65,17 @@ class InferenceXTestCase(unittest.TestCase):
 
         self.assertTrue(np.allclose(m.X.mean, mi.X.mean))
         self.assertTrue(np.allclose(m.X.variance, mi.X.variance))
+
+    def test_inferenceX_GPLVM(self):
+        Ys = self.genData()
+        m = GPy.models.GPLVM(Ys[0],3,kernel=GPy.kern.RBF(3,ARD=True))
+
+        x,mi = m.infer_newX(m.Y, optimize=False)
+        self.assertTrue(mi.checkgrad())
+
+#        m.optimize(max_iters=10000)
+#        x,mi = m.infer_newX(m.Y)
+#        self.assertTrue(np.allclose(m.X, x))
 
 
 if __name__ == "__main__":
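
Note for reviewers: a minimal usage sketch of the API this patch exposes on all GP models. The toy data, random seed, kernel choice, and iteration count below are illustrative assumptions, not part of the patch; the GPLVM constructor call mirrors the new test case in GPy/testing/inference_tests.py.

import numpy as np
import GPy

np.random.seed(0)
Y = np.random.randn(20, 5)  # toy data: 20 observations, 5 output dimensions

# A plain (non-Bayesian) GPLVM with a 3-dimensional latent space,
# mirroring the new test case.
m = GPy.models.GPLVM(Y, 3, kernel=GPy.kern.RBF(3, ARD=True))
m.optimize(max_iters=100)

# infer_newX is now defined on GPy.core.gp.GP itself, so it is available on
# GPLVM, SparseGPLVM and BayesianGPLVM alike. For a model without a
# variational_prior attribute the inferred X comes back as a point estimate
# (a Param named 'latent mean'); for uncertain-input models it remains a
# NormalPosterior with mean and variance, as before.
Y_new = np.random.randn(4, 5)
x_new, m_inf = m.infer_newX(Y_new, optimize=True)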