This commit is contained in:
Kevin M Jablonka 2025-07-31 20:37:42 +00:00 committed by GitHub
commit c4f2af0a73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,6 +6,8 @@ from ..core import GP
from .. import likelihoods
from .. import kern
from .. import util
from paramz import ObsAr
class GPCoregionalizedRegression(GP):
"""
@ -28,19 +30,70 @@ class GPCoregionalizedRegression(GP):
:param kernel_name: name of the kernel
:type kernel_name: string
"""
def __init__(self, X_list, Y_list, kernel=None, likelihoods_list=None, name='GPCR',W_rank=1,kernel_name='coreg'):
#Input and Output
X,Y,self.output_index = util.multioutput.build_XY(X_list,Y_list)
def __init__(
self,
X_list,
Y_list,
kernel=None,
normalizer=None,
likelihoods_list=None,
name="GPCR",
W_rank=1,
kernel_name="coreg",
):
# Input and Output
X, Y, self.output_index = util.multioutput.build_XY(X_list, Y_list)
Ny = len(Y_list)
#Kernel
# Kernel
if kernel is None:
kernel = kern.RBF(X.shape[1]-1)
kernel = util.multioutput.ICM(input_dim=X.shape[1]-1, num_outputs=Ny, kernel=kernel, W_rank=W_rank,name=kernel_name)
kernel = kern.RBF(X.shape[1] - 1)
#Likelihood
likelihood = util.multioutput.build_likelihood(Y_list,self.output_index,likelihoods_list)
kernel = util.multioutput.ICM(
input_dim=X.shape[1] - 1,
num_outputs=Ny,
kernel=kernel,
W_rank=W_rank,
name=kernel_name,
)
super(GPCoregionalizedRegression, self).__init__(X,Y,kernel,likelihood, Y_metadata={'output_index':self.output_index})
# Likelihood
likelihood = util.multioutput.build_likelihood(
Y_list, self.output_index, likelihoods_list
)
super(GPCoregionalizedRegression, self).__init__(
X,
Y,
kernel,
likelihood,
Y_metadata={"output_index": self.output_index},
normalizer=normalizer,
)
def set_XY(self, X=None, Y=None):
if isinstance(X, list):
X, _, self.output_index = util.multioutput.build_XY(X, None)
if isinstance(Y, list):
_, Y, self.output_index = util.multioutput.build_XY(Y, Y)
self.update_model(False)
if Y is not None:
if self.normalizer is not None:
self.normalizer.scale_by(Y)
self.Y_normalized = ObsAr(self.normalizer.normalize(Y))
self.Y = Y
else:
self.Y = ObsAr(Y)
self.Y_normalized = self.Y
if X is not None:
self.X = ObsAr(X)
self.Y_metadata = {
"output_index": self.output_index,
"trials": np.ones(self.output_index.shape),
}
self.update_model(True)