This commit is contained in:
Kevin M Jablonka 2025-07-31 20:37:42 +00:00 committed by GitHub
commit c4f2af0a73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,6 +6,8 @@ from ..core import GP
from .. import likelihoods from .. import likelihoods
from .. import kern from .. import kern
from .. import util from .. import util
from paramz import ObsAr
class GPCoregionalizedRegression(GP): class GPCoregionalizedRegression(GP):
""" """
@ -28,19 +30,70 @@ class GPCoregionalizedRegression(GP):
:param kernel_name: name of the kernel :param kernel_name: name of the kernel
:type kernel_name: string :type kernel_name: string
""" """
def __init__(self, X_list, Y_list, kernel=None, likelihoods_list=None, name='GPCR',W_rank=1,kernel_name='coreg'):
#Input and Output def __init__(
X,Y,self.output_index = util.multioutput.build_XY(X_list,Y_list) self,
X_list,
Y_list,
kernel=None,
normalizer=None,
likelihoods_list=None,
name="GPCR",
W_rank=1,
kernel_name="coreg",
):
# Input and Output
X, Y, self.output_index = util.multioutput.build_XY(X_list, Y_list)
Ny = len(Y_list) Ny = len(Y_list)
#Kernel # Kernel
if kernel is None: if kernel is None:
kernel = kern.RBF(X.shape[1]-1) kernel = kern.RBF(X.shape[1] - 1)
kernel = util.multioutput.ICM(input_dim=X.shape[1]-1, num_outputs=Ny, kernel=kernel, W_rank=W_rank,name=kernel_name)
#Likelihood kernel = util.multioutput.ICM(
likelihood = util.multioutput.build_likelihood(Y_list,self.output_index,likelihoods_list) input_dim=X.shape[1] - 1,
num_outputs=Ny,
kernel=kernel,
W_rank=W_rank,
name=kernel_name,
)
super(GPCoregionalizedRegression, self).__init__(X,Y,kernel,likelihood, Y_metadata={'output_index':self.output_index}) # Likelihood
likelihood = util.multioutput.build_likelihood(
Y_list, self.output_index, likelihoods_list
)
super(GPCoregionalizedRegression, self).__init__(
X,
Y,
kernel,
likelihood,
Y_metadata={"output_index": self.output_index},
normalizer=normalizer,
)
def set_XY(self, X=None, Y=None):
if isinstance(X, list):
X, _, self.output_index = util.multioutput.build_XY(X, None)
if isinstance(Y, list):
_, Y, self.output_index = util.multioutput.build_XY(Y, Y)
self.update_model(False)
if Y is not None:
if self.normalizer is not None:
self.normalizer.scale_by(Y)
self.Y_normalized = ObsAr(self.normalizer.normalize(Y))
self.Y = Y
else:
self.Y = ObsAr(Y)
self.Y_normalized = self.Y
if X is not None:
self.X = ObsAr(X)
self.Y_metadata = {
"output_index": self.output_index,
"trials": np.ones(self.output_index.shape),
}
self.update_model(True)