# Provenance (recovered from the patch header this module was extracted from):
#   commit : 1a59d527a3327847918099880ef2b30f70b15ca5
#   author : Zhenwen Dai
#   date   : Mon, 13 Oct 2014 18:00:11 +0100
#   subject: generalize the inference of X framework and with support of
#            missing data
#   path   : GPy/inference/latent_function_inference/inference_X.py
"""
Inference of the latent inputs X that correspond to new observations Y.

The entry point is :func:`inference_newX`; the optimisation problem itself is
wrapped in the :class:`Inference_X` model.
"""
from copy import deepcopy

import numpy as np

from ...core import Model
from ...core.parameterization import variational


def inference_newX(model, Y_new, optimize=True):
    """
    Build an :class:`Inference_X` model for ``Y_new`` and optionally optimise it.

    :param model: the existing model whose likelihood, kernel, posterior,
        variational prior and inducing inputs ``Z`` are copied.
    :param Y_new: the new observations to infer latent inputs for.
    :param optimize: when true, call ``optimize()`` on the inference model
        before returning.
    :returns: a tuple ``(X, infr_m)`` with the variational posterior over the
        new latent inputs and the :class:`Inference_X` instance that owns it.
    """
    infr_m = Inference_X(model, Y_new)

    if optimize:
        infr_m.optimize()

    return infr_m.X, infr_m


class Inference_X(Model):
    """The class for inference of new X with given new Y. (do_test_latent)"""

    def __init__(self, model, Y, name='inference_X'):
        """
        :param model: the existing model; its ``likelihood``, ``kern``,
            ``posterior``, ``variational_prior`` and ``Z`` are copied so that
            this object does not modify the original.
        :param Y: the new observations. If it contains any NaN it must hold
            exactly one row; the NaN columns of that row are treated as
            missing output dimensions.
        :param name: the name passed on to the ``Model`` constructor.
        :raises AssertionError: if ``Y`` contains NaN and has more than one row.
        """
        if np.isnan(Y).any():
            assert Y.shape[0]==1, "The current implementation of inference X only support one data point at a time with missing data!"
            self.missing_data = True
            # Boolean mask over output dimensions that are actually observed.
            self.valid_dim = np.logical_not(np.isnan(Y[0]))
        else:
            self.missing_data = False
        super(Inference_X, self).__init__(name)
        self.likelihood = model.likelihood.copy()
        self.kern = model.kern.copy()
        self.posterior = deepcopy(model.posterior)
        self.variational_prior = model.variational_prior.copy()
        self.Z = model.Z.copy()
        self.Y = Y
        self.X = self._init_X(model, Y)
        # The dL_dpsi* coefficients only depend on quantities fixed above, so
        # they are computed once here rather than in parameters_changed().
        self.compute_dL()

        self.link_parameter(self.X)

    def _init_X(self, model, Y_new):
        """
        Initialise the new X by finding the nearest point in Y space.

        For every row of ``Y_new`` the row of ``model.Y`` with the smallest
        squared Euclidean distance is located (restricted to the observed
        dimensions when data is missing), and the mean/variance (and gamma for
        an SSGPLVM) of ``model.X`` at those rows seed the new posterior.

        :returns: a ``SpikeAndSlabPosterior`` if ``model`` is an ``SSGPLVM``,
            otherwise a ``NormalPosterior``.
        """
        Y = model.Y
        if self.missing_data:
            Y = Y[:, self.valid_dim]
            Y_new = Y_new[:, self.valid_dim]
        # Pairwise squared distances via ||a||^2 - 2 a.b + ||b||^2.
        l2 = (-2.*Y_new.dot(Y.T)
              + np.square(Y_new).sum(axis=1)[:, None]
              + np.square(Y).sum(axis=1)[None, :])
        idx = l2.argmin(axis=1)

        # NOTE(review): these imports were function-local in the original,
        # presumably to avoid a circular import -- keep them local.
        from ...models import SSGPLVM
        from ...util.misc import param_to_array
        if isinstance(model, SSGPLVM):
            X = variational.SpikeAndSlabPosterior(
                param_to_array(model.X.mean[idx]),
                param_to_array(model.X.variance[idx]),
                param_to_array(model.X.gamma[idx]))
            if model.group_spike:
                # Tie columns together
                for i in range(X.gamma.shape[1]):
                    X.gamma[:, i].tie_together()
        else:
            X = variational.NormalPosterior(
                param_to_array(model.X.mean[idx]),
                param_to_array(model.X.variance[idx]))

        return X

    def compute_dL(self):
        """
        Compute ``dL_dpsi0``, ``dL_dpsi1`` and ``dL_dpsi2``.

        These are the coefficients that multiply the kernel's psi statistics
        in :meth:`parameters_changed`. With missing data only the observed
        output dimensions (``self.valid_dim``) of ``Y`` and of the posterior's
        woodbury vector are used.
        """
        # Common computation
        beta = 1./np.fmax(self.likelihood.variance, 1e-6)
        wv = self.posterior.woodbury_vector
        if self.missing_data:
            wv = wv[:, self.valid_dim]
            Y = self.Y[:, self.valid_dim]
            output_dim = self.valid_dim.sum()
        else:
            Y = self.Y
            output_dim = self.Y.shape[-1]

        self.dL_dpsi2 = beta*(output_dim*self.posterior.woodbury_inv - np.einsum('md,od->mo', wv, wv))/2.
        self.dL_dpsi1 = beta*np.dot(Y, wv.T)
        self.dL_dpsi0 = -output_dim * beta/2. * np.ones(self.Y.shape[0])

    def parameters_changed(self):
        """
        Recompute the objective and the gradients with respect to ``self.X``.

        The objective is the sum of the psi statistics weighted by the
        precomputed ``dL_dpsi*`` coefficients, minus the KL divergence of
        ``self.X`` under the variational prior.
        """
        psi0 = self.kern.psi0(self.Z, self.X)
        psi1 = self.kern.psi1(self.Z, self.X)
        psi2 = self.kern.psi2(self.Z, self.X)

        self._log_marginal_likelihood = (
            (self.dL_dpsi2*psi2).sum()
            + (self.dL_dpsi1*psi1).sum()
            + (self.dL_dpsi0*psi0).sum())
        X_grad = self.kern.gradients_qX_expectations(
            variational_posterior=self.X, Z=self.Z,
            dL_dpsi0=self.dL_dpsi0, dL_dpsi1=self.dL_dpsi1,
            dL_dpsi2=self.dL_dpsi2)
        # Set the likelihood-term gradients first; the KL update below is
        # called afterwards, exactly as in the original ordering.
        self.X.set_gradients(X_grad)

        self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X)
        self.variational_prior.update_gradients_KL(self.X)

    def log_likelihood(self):
        """Return the objective value cached by :meth:`parameters_changed`."""
        return self._log_marginal_likelihood