feat: adding normalizer argument in constructor

This commit is contained in:
Kevin 2020-08-23 13:40:37 +02:00
parent 8ba67e00ca
commit 3f090751b3

View file

@ -8,6 +8,7 @@ from .. import kern
from .. import util from .. import util
from paramz import ObsAr from paramz import ObsAr
class GPCoregionalizedRegression(GP): class GPCoregionalizedRegression(GP):
""" """
Gaussian Process model for heteroscedastic multioutput regression Gaussian Process model for heteroscedastic multioutput regression
@ -35,6 +36,7 @@ class GPCoregionalizedRegression(GP):
X_list, X_list,
Y_list, Y_list,
kernel=None, kernel=None,
normalizer=None,
likelihoods_list=None, likelihoods_list=None,
name="GPCR", name="GPCR",
W_rank=1, W_rank=1,
@ -63,7 +65,12 @@ class GPCoregionalizedRegression(GP):
) )
super(GPCoregionalizedRegression, self).__init__( super(GPCoregionalizedRegression, self).__init__(
X, Y, kernel, likelihood, Y_metadata={"output_index": self.output_index} X,
Y,
kernel,
likelihood,
Y_metadata={"output_index": self.output_index},
normalizer=normalizer,
) )
def set_XY(self, X=None, Y=None): def set_XY(self, X=None, Y=None):
@ -74,6 +81,11 @@ class GPCoregionalizedRegression(GP):
self.update_model(False) self.update_model(False)
if Y is not None: if Y is not None:
if self.normalizer is not None:
self.normalizer.scale_by(Y)
self.Y_normalized = ObsAr(self.normalizer.normalize(Y))
self.Y = Y
else:
self.Y = ObsAr(Y) self.Y = ObsAr(Y)
self.Y_normalized = self.Y self.Y_normalized = self.Y
if X is not None: if X is not None: