diff --git a/GPy/inference/latent_function_inference/inference_X.py b/GPy/inference/latent_function_inference/inference_X.py index de3c475c..ff18680b 100644 --- a/GPy/inference/latent_function_inference/inference_X.py +++ b/GPy/inference/latent_function_inference/inference_X.py @@ -14,7 +14,7 @@ def inference_newX(model, Y_new, optimize=True): class Inference_X(Model): """The class for inference of new X with given new Y. (do_test_latent)""" - def __init__(self, model, Y, name='inference_X'): + def __init__(self, model, Y, name='inference_X', init='L2'): """TODO: give comments""" if np.isnan(Y).any(): assert Y.shape[0]==1, "The current implementation of inference X only support one data point at a time with missing data!" @@ -30,22 +30,25 @@ class Inference_X(Model): self.variational_prior = model.variational_prior.copy() self.Z = model.Z.copy() self.Y = Y - self.X = self._init_X(model, Y) + self.X = self._init_X(model, Y, init=init) self.compute_dL() self.link_parameter(self.X) - def _init_X(self, model, Y_new): + def _init_X(self, model, Y_new, init='L2'): # Initialize the new X by finding the nearest point in Y space. Y = model.Y if self.missing_data: Y = Y[:,self.valid_dim] Y_new = Y_new[:,self.valid_dim] - l2 = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:] + dist = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:] else: - l2 = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:] - idx = l2.argmin(axis=1) + if init=='L2': + dist = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:] + elif init=='NCC': + dist = Y_new.dot(Y.T) + idx = dist.argmin(axis=1) from ...models import SSGPLVM from ...util.misc import param_to_array