kernel adding now takes over constraints

This commit is contained in:
Max Zwiessele 2014-02-11 15:23:49 +00:00
parent b7312a1b99
commit 4cfc13d5fc
4 changed files with 9 additions and 12 deletions

View file

@ -7,6 +7,7 @@ from gp import GP
from parameterization.param import Param
from ..inference.latent_function_inference import varDTC
from .. import likelihoods
from GPy.util.misc import param_to_array
class SparseGP(GP):
"""
@ -54,7 +55,10 @@ class SparseGP(GP):
self.add_parameter(self.Z, index=0)
def parameters_changed(self):
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.X_variance, self.Z, self.likelihood, self.Y)
Xvar = self.X_variance
if self.X_variance is not None:
Xvar = param_to_array(self.X_variance)
self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, param_to_array(self.X), Xvar, param_to_array(self.Z), self.likelihood, self.Y)
#The derivative of the bound wrt the inducing inputs Z
self.Z.gradient = self.kern.gradients_X(self.grad_dict['dL_dKmm'], self.Z)

View file

@ -6,6 +6,7 @@ from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dpotri, sy
import numpy as np
from ...util.linalg import dtrtri
from ...util.caching import Cacher
from ...util.misc import param_to_array
log_2_pi = np.log(2*np.pi)
class VarDTC(object):
@ -25,7 +26,7 @@ class VarDTC(object):
self.get_YYTfactor = Cacher(self._get_YYTfactor, 1)
def _get_trYYT(self, Y):
return np.sum(np.square(Y))
return param_to_array(np.sum(np.square(Y)))
def _get_YYTfactor(self, Y):
"""
@ -35,7 +36,7 @@ class VarDTC(object):
"""
N, D = Y.shape
if (N>D):
return Y
return param_to_array(Y)
else:
return jitchol(tdot(Y))

View file

@ -71,16 +71,13 @@ class DPsiStatTest(unittest.TestCase):
for k in self.kernels:
m = PsiStatModel('psi0', X=self.X, X_variance=self.X_var, Z=self.Z,\
num_inducing=self.num_inducing, kernel=k)
#m.ensure_default_constraints(warning=0)
m.randomize()
import ipdb;ipdb.set_trace()
assert m.checkgrad(), "{} x psi0".format("+".join(map(lambda x: x.name, k._parameters_)))
def testPsi1(self):
for k in self.kernels:
m = PsiStatModel('psi1', X=self.X, X_variance=self.X_var, Z=self.Z,
num_inducing=self.num_inducing, kernel=k)
m.ensure_default_constraints(warning=0)
m.randomize()
assert m.checkgrad(), "{} x psi1".format("+".join(map(lambda x: x.name, k._parameters_)))
@ -88,35 +85,30 @@ class DPsiStatTest(unittest.TestCase):
k = self.kernels[0]
m = PsiStatModel('psi2', X=self.X, X_variance=self.X_var, Z=self.Z,
num_inducing=self.num_inducing, kernel=k)
m.ensure_default_constraints(warning=0)
m.randomize()
assert m.checkgrad(), "{} x psi2".format("+".join(map(lambda x: x.name, k._parameters_)))
def testPsi2_lin_bia(self):
k = self.kernels[3]
m = PsiStatModel('psi2', X=self.X, X_variance=self.X_var, Z=self.Z,
num_inducing=self.num_inducing, kernel=k)
m.ensure_default_constraints(warning=0)
m.randomize()
assert m.checkgrad(), "{} x psi2".format("+".join(map(lambda x: x.name, k._parameters_)))
def testPsi2_rbf(self):
k = self.kernels[1]
m = PsiStatModel('psi2', X=self.X, X_variance=self.X_var, Z=self.Z,
num_inducing=self.num_inducing, kernel=k)
m.ensure_default_constraints(warning=0)
m.randomize()
assert m.checkgrad(), "{} x psi2".format("+".join(map(lambda x: x.name, k._parameters_)))
def testPsi2_rbf_bia(self):
k = self.kernels[-1]
m = PsiStatModel('psi2', X=self.X, X_variance=self.X_var, Z=self.Z,
num_inducing=self.num_inducing, kernel=k)
m.ensure_default_constraints(warning=0)
m.randomize()
assert m.checkgrad(), "{} x psi2".format("+".join(map(lambda x: x.name, k._parameters_)))
def testPsi2_bia(self):
k = self.kernels[2]
m = PsiStatModel('psi2', X=self.X, X_variance=self.X_var, Z=self.Z,
num_inducing=self.num_inducing, kernel=k)
m.ensure_default_constraints(warning=0)
m.randomize()
assert m.checkgrad(), "{} x psi2".format("+".join(map(lambda x: x.name, k._parameters_)))

View file

@ -184,4 +184,4 @@ from :class:ndarray)"""
assert len(param) > 0, "At least one parameter needed"
if len(param) == 1:
return param[0].view(np.ndarray)
return map(lambda x: x.view(np.ndarray), param)
return [x.view(np.ndarray) for x in param]