diff --git a/GPy/core/svgp.py b/GPy/core/svgp.py index 06a9749c..b8df625e 100644 --- a/GPy/core/svgp.py +++ b/GPy/core/svgp.py @@ -9,7 +9,7 @@ from ..inference.latent_function_inference import SVGP as svgp_inf class SVGP(SparseGP): - def __init__(self, X, Y, Z, kernel, likelihood, mean_function=None, name='SVGP', Y_metadata=None, batchsize=None): + def __init__(self, X, Y, Z, kernel, likelihood, mean_function=None, name='SVGP', Y_metadata=None, batchsize=None, num_latent_functions=None): """ Stochastic Variational GP. @@ -41,8 +41,12 @@ class SVGP(SparseGP): SparseGP.__init__(self, X_batch, Y_batch, Z, kernel, likelihood, mean_function=mean_function, inference_method=inf_method, name=name, Y_metadata=Y_metadata, normalizer=False) - self.m = Param('q_u_mean', np.zeros((self.num_inducing, Y.shape[1]))) - chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,Y.shape[1]))) + #assume the number of latent functions is one per col of Y unless specified + if num_latent_functions is None: + num_latent_functions = Y.shape[1] + + self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions))) + chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions))) self.chol = Param('q_u_chol', chol) self.link_parameter(self.chol) self.link_parameter(self.m)