mean functions in place

James Hensman 2015-03-23 14:47:49 +00:00
parent efb32282de
commit fa801bf46c
8 changed files with 18 additions and 9 deletions
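Every inference entry point touched below picks up the same two-line change: an optional mean_function=None keyword and a guard assertion, so callers can already pass the argument but inference refuses to run until the mean-function maths is actually wired in. A minimal sketch of the pattern (the class name and surrounding code here are illustrative, not GPy's actual classes):

class SomeInference(object):
    def inference(self, kern, X, likelihood, Y, mean_function=None, Y_metadata=None):
        # Accept the new keyword now, but refuse to proceed until the maths below uses it.
        assert mean_function is None, "inference with a mean function not implemented"
        ...  # existing zero-mean inference code, unchanged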

@@ -20,7 +20,8 @@ class DTC(LatentFunctionInference):
     def __init__(self):
         self.const_jitter = 1e-6
 
-    def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
+    def inference(self, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None):
+        assert mean_function is None, "inference with a mean function not implemented"
         assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
         num_inducing, _ = Z.shape
@@ -88,7 +89,8 @@ class vDTC(object):
     def __init__(self):
         self.const_jitter = 1e-6
 
-    def inference(self, kern, X, X_variance, Z, likelihood, Y, Y_metadata):
+    def inference(self, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None):
+        assert mean_function is None, "inference with a mean function not implemented"
         assert X_variance is None, "cannot use X_variance with DTC. Try varDTC."
         num_inducing, _ = Z.shape

@@ -36,10 +36,12 @@ class ExactGaussianInference(LatentFunctionInference):
         #print "WARNING: N>D of Y, we need caching of L, such that L*L^T = Y, returning Y still!"
         return Y
 
-    def inference(self, kern, X, likelihood, Y, Y_metadata=None):
+    def inference(self, kern, X, likelihood, Y, mean_function=None, Y_metadata=None):
         """
         Returns a Posterior class containing essential quantities of the posterior
         """
+        assert mean_function is None, "inference with a mean function not implemented"
 
         YYT_factor = self.get_YYTfactor(Y)
         K = kern.K(X)
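For the exact Gaussian case, supporting a nonzero mean later reduces to the standard identity: run the zero-mean algebra on the residuals Y - m(X) and add m(X*) back to the predictive mean. A standalone NumPy sketch of that identity (function and argument names are illustrative, not GPy's API; kern and mean_function are assumed to be plain callables):

import numpy as np

def gp_posterior_mean(X, Y, Xstar, kern, mean_function, noise_var):
    # A GP with mean m(.) is a zero-mean GP on the residuals Y - m(X).
    resid = Y - mean_function(X)
    K = kern(X, X) + noise_var * np.eye(X.shape[0])
    Kstar = kern(Xstar, X)
    alpha = np.linalg.solve(K, resid)
    # Add the mean back at the test inputs.
    return mean_function(Xstar) + Kstar.dot(alpha)

With mean_function = lambda X_: np.zeros((X_.shape[0], 1)) this collapses to the zero-mean computation the code above already performs.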

@@ -33,7 +33,8 @@ class EP(LatentFunctionInference):
         # TODO: update approximation in the end as well? Maybe even with a switch?
         pass
 
-    def inference(self, kern, X, likelihood, Y, Y_metadata=None, Z=None):
+    def inference(self, kern, X, likelihood, Y, mean_function=None, Y_metadata=None, Z=None):
+        assert mean_function is None, "inference with a mean function not implemented"
         num_data, output_dim = Y.shape
         assert output_dim ==1, "ep in 1D only (for now!)"

@@ -64,7 +64,8 @@ class EPDTC(LatentFunctionInference):
         self.old_mutilde, self.old_vtilde = None, None
         self._ep_approximation = None
 
-    def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
+    def inference(self, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None):
+        assert mean_function is None, "inference with a mean function not implemented"
         num_data, output_dim = Y.shape
         assert output_dim ==1, "ep in 1D only (for now!)"

@@ -18,7 +18,8 @@ class FITC(LatentFunctionInference):
     """
     const_jitter = 1e-6
 
-    def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None):
+    def inference(self, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None):
+        assert mean_function is None, "inference with a mean function not implemented"
         num_inducing, _ = Z.shape
         num_data, output_dim = Y.shape

@@ -39,10 +39,11 @@ class Laplace(LatentFunctionInference):
         self.first_run = True
         self._previous_Ki_fhat = None
 
-    def inference(self, kern, X, likelihood, Y, Y_metadata=None):
+    def inference(self, kern, X, likelihood, Y, mean_function=None, Y_metadata=None):
         """
         Returns a Posterior class containing essential quantities of the posterior
         """
+        assert mean_function is None, "inference with a mean function not implemented"
 
         # Compute K
         K = kern.K(X)

@@ -6,7 +6,8 @@ from posterior import Posterior
 
 class SVGP(LatentFunctionInference):
-    def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, Y_metadata=None, KL_scale=1.0, batch_scale=1.0):
+    def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, KL_scale=1.0, batch_scale=1.0):
+        assert mean_function is None, "inference with a mean function not implemented"
         num_inducing = Z.shape[0]
         num_data, num_outputs = Y.shape

@@ -33,7 +33,7 @@ class Linear(Mapping):
         return np.dot(X, self.A)
 
     def update_gradients(self, dL_dF, X):
-        self.A.gradient = np.dot( X.T dL_dF)
+        self.A.gradient = np.dot( X.T, dL_dF)
 
     def gradients_X(self, dL_dF, X):
         return np.dot(dL_dF, self.A.T)
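The final hunk is a straight bug fix: for the linear mapping F = X A, the chain rule gives dL/dA = X^T dL/dF, which is exactly np.dot(X.T, dL_dF); the old line was simply missing the comma. A quick finite-difference check of that gradient identity (standalone NumPy, not GPy code):

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(5, 3)
A = rng.randn(3, 2)
W = rng.randn(5, 2)        # plays the role of dL_dF for the loss L = sum(W * (X A))

analytic = np.dot(X.T, W)  # the corrected gradient expression

eps = 1e-6
numeric = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        Ap, Am = A.copy(), A.copy()
        Ap[i, j] += eps
        Am[i, j] -= eps
        numeric[i, j] = (np.sum(W * X.dot(Ap)) - np.sum(W * X.dot(Am))) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)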