[stationary] lengthscales will be scaled by variance now

This commit is contained in:
Max Zwiessele 2014-11-28 10:10:52 +00:00
parent be40307597
commit 45ede97d85

View file

@ -159,7 +159,7 @@ class Stationary(Kern):
#self.lengthscale.gradient = -((dL_dr*rinv)[:,:,None]*x_xl3).sum(0).sum(0)/self.lengthscale**3 #self.lengthscale.gradient = -((dL_dr*rinv)[:,:,None]*x_xl3).sum(0).sum(0)/self.lengthscale**3
tmp = dL_dr*self._inv_dist(X, X2) tmp = dL_dr*self._inv_dist(X, X2)
if X2 is None: X2 = X if X2 is None: X2 = X
if config.getboolean('weave', 'working'): if config.getboolean('weave', 'working'):
try: try:
@ -261,7 +261,7 @@ class Stationary(Kern):
ret(n,d) = retnd; ret(n,d) = retnd;
} }
} }
""" """
if hasattr(X, 'values'):X = X.values #remove the GPy wrapping to make passing into weave safe if hasattr(X, 'values'):X = X.values #remove the GPy wrapping to make passing into weave safe
if hasattr(X2, 'values'):X2 = X2.values if hasattr(X2, 'values'):X2 = X2.values
@ -278,12 +278,12 @@ class Stationary(Kern):
'extra_link_args' : ['-lgomp']} 'extra_link_args' : ['-lgomp']}
weave.inline(code, ['ret', 'N', 'D', 'M', 'tmp', 'X', 'X2'], type_converters=weave.converters.blitz, support_code=support_code, **weave_options) weave.inline(code, ['ret', 'N', 'D', 'M', 'tmp', 'X', 'X2'], type_converters=weave.converters.blitz, support_code=support_code, **weave_options)
return ret/self.lengthscale**2 return ret/self.lengthscale**2
def gradients_X_diag(self, dL_dKdiag, X): def gradients_X_diag(self, dL_dKdiag, X):
return np.zeros(X.shape) return np.zeros(X.shape)
def input_sensitivity(self, summarize=True): def input_sensitivity(self, summarize=True):
return np.ones(self.input_dim)/self.lengthscale**2 return self.variance*np.ones(self.input_dim)/self.lengthscale**2
class Exponential(Stationary): class Exponential(Stationary):
def __init__(self, input_dim, variance=1., lengthscale=None, ARD=False, active_dims=None, name='Exponential'): def __init__(self, input_dim, variance=1., lengthscale=None, ARD=False, active_dims=None, name='Exponential'):