metadata passing in fitc

This commit is contained in:
James Hensman 2014-03-20 09:25:05 +00:00
parent c353ac67e6
commit ff88845f99
2 changed files with 10 additions and 18 deletions

View file

@ -19,19 +19,15 @@ class DTC(object):
def __init__(self): def __init__(self):
self.const_jitter = 1e-6 self.const_jitter = 1e-6
def inference(self, kern, X, Z, likelihood, Y): def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
assert X_variance is None, "cannot use X_variance with DTC. Try varDTC." assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
#TODO: MAX! fix this!
from ...util.misc import param_to_array
Y = param_to_array(Y)
num_inducing, _ = Z.shape num_inducing, _ = Z.shape
num_data, output_dim = Y.shape num_data, output_dim = Y.shape
#make sure the noise is not hetero #make sure the noise is not hetero
beta = 1./np.squeeze(likelihood.variance) beta = 1./likelihood.gaussian_variance(Y_metadata)
if beta.size <1: if beta.size > 1:
raise NotImplementedError, "no hetero noise with this implementation of DTC" raise NotImplementedError, "no hetero noise with this implementation of DTC"
Kmm = kern.K(Z) Kmm = kern.K(Z)
@ -91,19 +87,15 @@ class vDTC(object):
def __init__(self): def __init__(self):
self.const_jitter = 1e-6 self.const_jitter = 1e-6
def inference(self, kern, X, X_variance, Z, likelihood, Y): def inference(self, kern, X, X_variance, Z, likelihood, Y, Y_metadata):
assert X_variance is None, "cannot use X_variance with DTC. Try varDTC." assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
#TODO: MAX! fix this!
from ...util.misc import param_to_array
Y = param_to_array(Y)
num_inducing, _ = Z.shape num_inducing, _ = Z.shape
num_data, output_dim = Y.shape num_data, output_dim = Y.shape
#make sure the noise is not hetero #make sure the noise is not hetero
beta = 1./np.squeeze(likelihood.variance) beta = 1./likelihood.gaussian_variance(Y_metadata)
if beta.size <1: if beta.size > 1:
raise NotImplementedError, "no hetero noise with this implementation of DTC" raise NotImplementedError, "no hetero noise with this implementation of DTC"
Kmm = kern.K(Z) Kmm = kern.K(Z)

View file

@ -17,14 +17,14 @@ class FITC(object):
""" """
const_jitter = 1e-6 const_jitter = 1e-6
def inference(self, kern, X, Z, likelihood, Y): def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
num_inducing, _ = Z.shape num_inducing, _ = Z.shape
num_data, output_dim = Y.shape num_data, output_dim = Y.shape
#make sure the noise is not hetero #make sure the noise is not hetero
sigma_n = np.squeeze(likelihood.variance) sigma_n = likelihood.gaussian_variance(Y_metadata)
if sigma_n.size <1: if sigma_n.size >1:
raise NotImplementedError, "no hetero noise with this implementation of FITC" raise NotImplementedError, "no hetero noise with this implementation of FITC"
Kmm = kern.K(Z) Kmm = kern.K(Z)