mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
allow NCC init for inference_X
This commit is contained in:
parent
fd0f0f2902
commit
131f4f3fc3
1 changed files with 9 additions and 6 deletions
|
|
@ -14,7 +14,7 @@ def inference_newX(model, Y_new, optimize=True):
|
||||||
|
|
||||||
class Inference_X(Model):
|
class Inference_X(Model):
|
||||||
"""The class for inference of new X with given new Y. (do_test_latent)"""
|
"""The class for inference of new X with given new Y. (do_test_latent)"""
|
||||||
def __init__(self, model, Y, name='inference_X'):
|
def __init__(self, model, Y, name='inference_X', init='L2'):
|
||||||
"""TODO: give comments"""
|
"""TODO: give comments"""
|
||||||
if np.isnan(Y).any():
|
if np.isnan(Y).any():
|
||||||
assert Y.shape[0]==1, "The current implementation of inference X only support one data point at a time with missing data!"
|
assert Y.shape[0]==1, "The current implementation of inference X only support one data point at a time with missing data!"
|
||||||
|
|
@ -30,22 +30,25 @@ class Inference_X(Model):
|
||||||
self.variational_prior = model.variational_prior.copy()
|
self.variational_prior = model.variational_prior.copy()
|
||||||
self.Z = model.Z.copy()
|
self.Z = model.Z.copy()
|
||||||
self.Y = Y
|
self.Y = Y
|
||||||
self.X = self._init_X(model, Y)
|
self.X = self._init_X(model, Y, init=init)
|
||||||
self.compute_dL()
|
self.compute_dL()
|
||||||
|
|
||||||
self.link_parameter(self.X)
|
self.link_parameter(self.X)
|
||||||
|
|
||||||
def _init_X(self, model, Y_new):
|
def _init_X(self, model, Y_new, init='L2'):
|
||||||
# Initialize the new X by finding the nearest point in Y space.
|
# Initialize the new X by finding the nearest point in Y space.
|
||||||
|
|
||||||
Y = model.Y
|
Y = model.Y
|
||||||
if self.missing_data:
|
if self.missing_data:
|
||||||
Y = Y[:,self.valid_dim]
|
Y = Y[:,self.valid_dim]
|
||||||
Y_new = Y_new[:,self.valid_dim]
|
Y_new = Y_new[:,self.valid_dim]
|
||||||
l2 = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
|
dist = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
|
||||||
else:
|
else:
|
||||||
l2 = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
|
if init=='L2':
|
||||||
idx = l2.argmin(axis=1)
|
dist = -2.*Y_new.dot(Y.T) + np.square(Y_new).sum(axis=1)[:,None]+ np.square(Y).sum(axis=1)[None,:]
|
||||||
|
elif init=='NCC':
|
||||||
|
dist = Y_new.dot(Y.T)
|
||||||
|
idx = dist.argmin(axis=1)
|
||||||
|
|
||||||
from ...models import SSGPLVM
|
from ...models import SSGPLVM
|
||||||
from ...util.misc import param_to_array
|
from ...util.misc import param_to_array
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue