feat: added set_xy for coregionalized model

This commit is contained in:
Kevin 2020-08-21 17:17:15 +02:00
parent ec20f9ed3a
commit 8ba67e00ca

View file

@ -6,6 +6,7 @@ from ..core import GP
from .. import likelihoods
from .. import kern
from .. import util
from paramz import ObsAr
class GPCoregionalizedRegression(GP):
"""
@ -28,19 +29,59 @@ class GPCoregionalizedRegression(GP):
:param kernel_name: name of the kernel
:type kernel_name: string
"""
def __init__(self, X_list, Y_list, kernel=None, likelihoods_list=None, name='GPCR',W_rank=1,kernel_name='coreg'):
#Input and Output
X,Y,self.output_index = util.multioutput.build_XY(X_list,Y_list)
def __init__(
self,
X_list,
Y_list,
kernel=None,
likelihoods_list=None,
name="GPCR",
W_rank=1,
kernel_name="coreg",
):
# Input and Output
X, Y, self.output_index = util.multioutput.build_XY(X_list, Y_list)
Ny = len(Y_list)
#Kernel
# Kernel
if kernel is None:
kernel = kern.RBF(X.shape[1]-1)
kernel = util.multioutput.ICM(input_dim=X.shape[1]-1, num_outputs=Ny, kernel=kernel, W_rank=W_rank,name=kernel_name)
kernel = kern.RBF(X.shape[1] - 1)
#Likelihood
likelihood = util.multioutput.build_likelihood(Y_list,self.output_index,likelihoods_list)
kernel = util.multioutput.ICM(
input_dim=X.shape[1] - 1,
num_outputs=Ny,
kernel=kernel,
W_rank=W_rank,
name=kernel_name,
)
super(GPCoregionalizedRegression, self).__init__(X,Y,kernel,likelihood, Y_metadata={'output_index':self.output_index})
# Likelihood
likelihood = util.multioutput.build_likelihood(
Y_list, self.output_index, likelihoods_list
)
super(GPCoregionalizedRegression, self).__init__(
X, Y, kernel, likelihood, Y_metadata={"output_index": self.output_index}
)
def set_XY(self, X=None, Y=None):
if isinstance(X, list):
X, _, self.output_index = util.multioutput.build_XY(X, None)
if isinstance(Y, list):
_, Y, self.output_index = util.multioutput.build_XY(Y, Y)
self.update_model(False)
if Y is not None:
self.Y = ObsAr(Y)
self.Y_normalized = self.Y
if X is not None:
self.X = ObsAr(X)
self.Y_metadata = {
"output_index": self.output_index,
"trials": np.ones(self.output_index.shape),
}
self.update_model(True)