diff --git a/GPy/inference/latent_function_inference/dtc.py b/GPy/inference/latent_function_inference/dtc.py index 5ebc5e53..1a84da6b 100644 --- a/GPy/inference/latent_function_inference/dtc.py +++ b/GPy/inference/latent_function_inference/dtc.py @@ -19,19 +19,15 @@ class DTC(object): def __init__(self): self.const_jitter = 1e-6 - def inference(self, kern, X, Z, likelihood, Y): + def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None): assert X_variance is None, "cannot use X_variance with DTC. Try varDTC." - #TODO: MAX! fix this! - from ...util.misc import param_to_array - Y = param_to_array(Y) - num_inducing, _ = Z.shape num_data, output_dim = Y.shape #make sure the noise is not hetero - beta = 1./np.squeeze(likelihood.variance) - if beta.size <1: + beta = 1./likelihood.gaussian_variance(Y_metadata) + if beta.size > 1: raise NotImplementedError, "no hetero noise with this implementation of DTC" Kmm = kern.K(Z) @@ -91,19 +87,15 @@ class vDTC(object): def __init__(self): self.const_jitter = 1e-6 - def inference(self, kern, X, X_variance, Z, likelihood, Y): + def inference(self, kern, X, X_variance, Z, likelihood, Y, Y_metadata): assert X_variance is None, "cannot use X_variance with DTC. Try varDTC." - #TODO: MAX! fix this! - from ...util.misc import param_to_array - Y = param_to_array(Y) - num_inducing, _ = Z.shape num_data, output_dim = Y.shape #make sure the noise is not hetero - beta = 1./np.squeeze(likelihood.variance) - if beta.size <1: + beta = 1./likelihood.gaussian_variance(Y_metadata) + if beta.size > 1: raise NotImplementedError, "no hetero noise with this implementation of DTC" Kmm = kern.K(Z) @@ -112,7 +104,7 @@ class vDTC(object): U = Knm Uy = np.dot(U.T,Y) - #factor Kmm + #factor Kmm Kmmi, L, Li, _ = pdinv(Kmm) # Compute A diff --git a/GPy/inference/latent_function_inference/fitc.py b/GPy/inference/latent_function_inference/fitc.py index c4147d06..de47e5d5 100644 --- a/GPy/inference/latent_function_inference/fitc.py +++ b/GPy/inference/latent_function_inference/fitc.py @@ -17,14 +17,14 @@ class FITC(object): """ const_jitter = 1e-6 - def inference(self, kern, X, Z, likelihood, Y): + def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None): num_inducing, _ = Z.shape num_data, output_dim = Y.shape #make sure the noise is not hetero - sigma_n = np.squeeze(likelihood.variance) - if sigma_n.size <1: + sigma_n = likelihood.gaussian_variance(Y_metadata) + if sigma_n.size >1: raise NotImplementedError, "no hetero noise with this implementation of FITC" Kmm = kern.K(Z)