mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
metadata passing in fitc
This commit is contained in:
parent
c353ac67e6
commit
ff88845f99
2 changed files with 10 additions and 18 deletions
|
|
@ -19,19 +19,15 @@ class DTC(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.const_jitter = 1e-6
|
self.const_jitter = 1e-6
|
||||||
|
|
||||||
def inference(self, kern, X, Z, likelihood, Y):
|
def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
|
||||||
assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
|
assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
|
||||||
|
|
||||||
#TODO: MAX! fix this!
|
|
||||||
from ...util.misc import param_to_array
|
|
||||||
Y = param_to_array(Y)
|
|
||||||
|
|
||||||
num_inducing, _ = Z.shape
|
num_inducing, _ = Z.shape
|
||||||
num_data, output_dim = Y.shape
|
num_data, output_dim = Y.shape
|
||||||
|
|
||||||
#make sure the noise is not hetero
|
#make sure the noise is not hetero
|
||||||
beta = 1./np.squeeze(likelihood.variance)
|
beta = 1./likelihood.gaussian_variance(Y_metadata)
|
||||||
if beta.size <1:
|
if beta.size > 1:
|
||||||
raise NotImplementedError, "no hetero noise with this implementation of DTC"
|
raise NotImplementedError, "no hetero noise with this implementation of DTC"
|
||||||
|
|
||||||
Kmm = kern.K(Z)
|
Kmm = kern.K(Z)
|
||||||
|
|
@ -91,19 +87,15 @@ class vDTC(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.const_jitter = 1e-6
|
self.const_jitter = 1e-6
|
||||||
|
|
||||||
def inference(self, kern, X, X_variance, Z, likelihood, Y):
|
def inference(self, kern, X, X_variance, Z, likelihood, Y, Y_metadata):
|
||||||
assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
|
assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
|
||||||
|
|
||||||
#TODO: MAX! fix this!
|
|
||||||
from ...util.misc import param_to_array
|
|
||||||
Y = param_to_array(Y)
|
|
||||||
|
|
||||||
num_inducing, _ = Z.shape
|
num_inducing, _ = Z.shape
|
||||||
num_data, output_dim = Y.shape
|
num_data, output_dim = Y.shape
|
||||||
|
|
||||||
#make sure the noise is not hetero
|
#make sure the noise is not hetero
|
||||||
beta = 1./np.squeeze(likelihood.variance)
|
beta = 1./likelihood.gaussian_variance(Y_metadata)
|
||||||
if beta.size <1:
|
if beta.size > 1:
|
||||||
raise NotImplementedError, "no hetero noise with this implementation of DTC"
|
raise NotImplementedError, "no hetero noise with this implementation of DTC"
|
||||||
|
|
||||||
Kmm = kern.K(Z)
|
Kmm = kern.K(Z)
|
||||||
|
|
@ -112,7 +104,7 @@ class vDTC(object):
|
||||||
U = Knm
|
U = Knm
|
||||||
Uy = np.dot(U.T,Y)
|
Uy = np.dot(U.T,Y)
|
||||||
|
|
||||||
#factor Kmm
|
#factor Kmm
|
||||||
Kmmi, L, Li, _ = pdinv(Kmm)
|
Kmmi, L, Li, _ = pdinv(Kmm)
|
||||||
|
|
||||||
# Compute A
|
# Compute A
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,14 @@ class FITC(object):
|
||||||
"""
|
"""
|
||||||
const_jitter = 1e-6
|
const_jitter = 1e-6
|
||||||
|
|
||||||
def inference(self, kern, X, Z, likelihood, Y):
|
def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
|
||||||
|
|
||||||
num_inducing, _ = Z.shape
|
num_inducing, _ = Z.shape
|
||||||
num_data, output_dim = Y.shape
|
num_data, output_dim = Y.shape
|
||||||
|
|
||||||
#make sure the noise is not hetero
|
#make sure the noise is not hetero
|
||||||
sigma_n = np.squeeze(likelihood.variance)
|
sigma_n = likelihood.gaussian_variance(Y_metadata)
|
||||||
if sigma_n.size <1:
|
if sigma_n.size >1:
|
||||||
raise NotImplementedError, "no hetero noise with this implementation of FITC"
|
raise NotImplementedError, "no hetero noise with this implementation of FITC"
|
||||||
|
|
||||||
Kmm = kern.K(Z)
|
Kmm = kern.K(Z)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue