input senitivity in stationary

This commit is contained in:
James Hensman 2014-02-24 14:49:20 +00:00
parent d3eaef5c99
commit 06dd27c634

View file

@ -35,6 +35,9 @@ class Stationary(Kern):
def K_of_r(self, r):
raise NotImplementedError, "implement the covaraiance functino and a fn of r to use this class"
def dK_dr(self, r):
raise NotImplementedError, "implement the covaraiance functino and a fn of r to use this class"
def K(self, X, X2=None):
r = self._scaled_dist(X, X2)
return self.K_of_r(r)
@ -92,55 +95,8 @@ class Stationary(Kern):
def gradients_X_diag(self, dL_dKdiag, X):
return np.zeros(X.shape)
def add(self, other, tensor=False):
if not tensor:
return StatAdd(self, other)
else:
return super(Stationary, self).add(other, tensor)
def prod(self, other, tensor=False):
if not tensor:
return StatProd(self, other)
else:
return super(Stationary, self).prod(other, tensor)
class StatAdd(Stationary):
"""
Addition of two Stationary kernels on the same space is still stationary.
If you need to add two (stationary) kernels on separate spaces, use the generic add class.
"""
def __init__(self, k1, k2):
assert isinstance(k1, Stationary)
assert isinstance(k2, Stationary)
self.k1, self.k2 = k1, k2
self.add_parameters(k1, k2)
def K_of_r(self, r):
return self.k1.K(r) + self.k2.K(r)
def dK_dr(self, r):
return self.k1.dK_dr + self.k2.dK_dr(r)
class StatProd(Stationary):
"""
Product of two Stationary kernels on the same space is still stationary.
If you need to multiply two (stationary) kernels on separate spaces, use the generic Prod class.
"""
def __init__(self, k1, k2):
assert isinstance(k1, Stationary)
assert isinstance(k2, Stationary)
self.k1, self.k2 = k1, k2
self.add_parameters(k1, k2)
def K_of_r(self, r):
return self.k1.K(r) * self.k2.K(r)
def dK_dr(self, r):
return self.k1.dK_dr(r) * self.k2.K_of_r(r) + self.k2.dK_dr(r) * self.k1.K_of_r(r)
def input_sensitivity(self):
return np.ones(self.input_dim)/self.lengthscale
class Exponential(Stationary):
def __init__(self, input_dim, variance=1., lengthscale=None, ARD=False, name='Exponential'):