allow NCC init for inference_X

This commit is contained in:
Zhenwen Dai 2014-10-23 18:00:25 +01:00
parent fd0f0f2902
commit 131f4f3fc3

View file

@ -14,7 +14,7 @@ def inference_newX(model, Y_new, optimize=True):
class Inference_X(Model):
"""The class for inference of new X with given new Y. (do_test_latent)"""
def __init__(self, model, Y, name='inference_X'):
def __init__(self, model, Y, name='inference_X', init='L2'):
"""TODO: give comments"""
if np.isnan(Y).any():
assert Y.shape[0]==1, "The current implementation of inference X only support one data point at a time with missing data!"
@ -30,22 +30,25 @@ class Inference_X(Model):
self.variational_prior = model.variational_prior.copy()
self.Z = model.Z.copy()
self.Y = Y
self.X = self._init_X(model, Y)
self.X = self._init_X(model, Y, init=init)
self.compute_dL()
self.link_parameter(self.X)
def _init_X(self, model, Y_new):
def _init_X(self, model, Y_new, init='L2'):
# Initialize the new X by finding the nearest point in Y space.
Y = model.Y
if self.missing_data:
Y = Y[:,self.valid_dim]
Y_new = Y_new[:,self.valid_dim]
l2 = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
dist = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
else:
l2 = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
idx = l2.argmin(axis=1)
if init=='L2':
dist = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
elif init=='NCC':
dist = Y_new.dot(Y.T)
idx = dist.argmin(axis=1)
from ...models import SSGPLVM
from ...util.misc import param_to_array