added precomputation of linear kernel, changed the logic a bit

Nicolo Fusi 2012-12-03 10:08:12 +00:00
parent 46754db658
commit 8b73dafbae


@@ -21,6 +21,7 @@ class linear(kernpart):
         self.Nparam = 1
         self.name = 'linear'
         self.set_param(variance)
+        self._Xcache, self._X2cache = np.empty(shape=(2,))

     def get_param(self):
         return self.variance
@@ -32,7 +33,8 @@
         return ['variance']

     def K(self,X,X2,target):
-        target += self.variance * np.dot(X, X2.T)
+        self._K_computations(X, X2)
+        target += self.variance * self._dot_product

     def Kdiag(self,X,target):
         np.add(target,np.sum(self.variance*np.square(X),-1),target)
@@ -42,7 +44,9 @@
        Computes the derivatives wrt theta
        Return shape is NxMx(Ntheta)
        """
-        product = np.dot(X, X2.T)
+        self._K_computations(X, X2)
+        product = self._dot_product
+        # product = np.dot(X, X2.T)
         target += np.sum(product*partial)

     def dK_dX(self,partial,X,X2,target):
@@ -51,6 +55,20 @@
     def dKdiag_dtheta(self,partial,X,target):
         target += np.sum(partial*np.square(X).sum(1))

+    def _K_computations(self,X,X2):
+        # (Nicolo) changed the logic here. If X2 is None, we want to cache
+        # (X,X). In practice X2 should always be passed.
+        if X2 is None:
+            X2 = X
+        if not (np.all(X==self._Xcache) and np.all(X2==self._X2cache)):
+            self._Xcache = X
+            self._X2cache = X2
+            self._dot_product = np.dot(X,X2.T)
+        else:
+            # print "Cache hit!"
+            pass # TODO: insert debug message here (logging framework)
+
     # def psi0(self,Z,mu,S,target):
     #     expected = np.square(mu) + S
     #     np.add(target,np.sum(self.variance*expected),target)
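
For reference, a minimal, self-contained sketch of the caching pattern this commit introduces. The class name LinearKernelSketch and the demo values are illustrative, not part of the commit, and the sketch assumes X and X2 keep the same shapes across calls (as they do when K and its gradients are evaluated back to back):

    import numpy as np

    class LinearKernelSketch:
        """Hypothetical stand-in for the linear kernpart, showing the cache."""
        def __init__(self, variance=1.0):
            self.variance = variance
            # Two uninitialised scalars act as sentinels: they will not match
            # any real input, so the first call always computes the product.
            self._Xcache, self._X2cache = np.empty(shape=(2,))

        def _K_computations(self, X, X2):
            # If X2 is None, fall back to K(X, X), as in the commit.
            if X2 is None:
                X2 = X
            # Recompute only when the inputs differ from the cached ones.
            if not (np.all(X == self._Xcache) and np.all(X2 == self._X2cache)):
                self._Xcache = X
                self._X2cache = X2
                self._dot_product = np.dot(X, X2.T)

        def K(self, X, X2, target):
            self._K_computations(X, X2)
            target += self.variance * self._dot_product

        def dK_dtheta(self, partial, X, X2, target):
            # Hits the cache when called right after K with the same inputs.
            self._K_computations(X, X2)
            target += np.sum(self._dot_product * partial)

    X = np.random.randn(5, 3)
    kern = LinearKernelSketch(variance=2.0)
    K = np.zeros((5, 5))
    kern.K(X, X, K)                              # computes np.dot(X, X.T) once
    grad = np.zeros(1)
    kern.dK_dtheta(np.ones((5, 5)), X, X, grad)  # reuses the cached product

One caveat worth noting when reading the diff: the cache stores references rather than copies, so if a caller mutates X in place the equality check still compares the object against itself, succeeds, and returns a stale _dot_product. Copying or hashing the inputs would make the cache safe against in-place mutation, at the cost of extra memory or compute.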