Update GPLVM class to use metadata and output normalizers.

This commit is contained in:
Neil Lawrence 2021-05-19 09:45:21 +01:00 committed by Neil Lawrence
parent 8b098ec59b
commit 5c71aa45c7

View file

@ -14,7 +14,7 @@ class GPLVM(GP):
""" """
def __init__(self, Y, input_dim, init='PCA', X=None, kernel=None, name="gplvm"): def __init__(self, Y, input_dim, init='PCA', X=None, kernel=None, name="gplvm", Y_metadata=None, normalizer=False):
""" """
:param Y: observed data :param Y: observed data
@ -23,6 +23,11 @@ class GPLVM(GP):
:type input_dim: int :type input_dim: int
:param init: initialisation method for the latent space :param init: initialisation method for the latent space
:type init: 'PCA'|'random' :type init: 'PCA'|'random'
:param normalizer:
normalize the outputs Y.
If normalizer is True, we will normalize using Standardize.
If normalizer is False (the default), no normalization will be done.
:type normalizer: bool
""" """
if X is None: if X is None:
from ..util.initialization import initialize_latent from ..util.initialization import initialize_latent
@ -34,7 +39,7 @@ class GPLVM(GP):
likelihood = Gaussian() likelihood = Gaussian()
super(GPLVM, self).__init__(X, Y, kernel, likelihood, name='GPLVM') super(GPLVM, self).__init__(X, Y, kernel, likelihood, name='GPLVM', Y_metadata=Y_metadata, normalizer=normalizer)
self.X = Param('latent_mean', X) self.X = Param('latent_mean', X)
self.link_parameter(self.X, index=0) self.link_parameter(self.X, index=0)