added trace_sum for efficiency

This commit is contained in:
James Hensman 2013-03-11 18:56:37 +00:00
parent 129bb3924e
commit cb082898d3
2 changed files with 10 additions and 3 deletions

View file

@ -3,7 +3,7 @@
import numpy as np import numpy as np
import pylab as pb import pylab as pb
from ..util.linalg import mdot, jitchol, chol_inv, pdinv from ..util.linalg import mdot, jitchol, chol_inv, pdinv, trace_dot
from ..util.plot import gpplot from ..util.plot import gpplot
from .. import kern from .. import kern
from GP import GP from GP import GP
@ -107,7 +107,7 @@ class sparse_GP(GP):
self.C = np.dot(tmp,tmp.T) self.C = np.dot(tmp,tmp.T)
#self.C = mdot(self.Lmi.T, self.Bi, self.Lmi) #self.C = mdot(self.Lmi.T, self.Bi, self.Lmi)
#self.E = mdot(self.C, self.psi1VVpsi1/sf2, self.C.T) #self.E = mdot(self.C, self.psi1VVpsi1/sf2, self.C.T)
tmp = np.dot(self.C,self.psi1V/sf) tmp = np.dot(self.C/sf,self.psi1V)
self.E = np.dot(tmp,tmp.T) self.E = np.dot(tmp,tmp.T)
# Compute dL_dpsi # FIXME: this is untested for the heterscedastic + uncertin inputs case # Compute dL_dpsi # FIXME: this is untested for the heterscedastic + uncertin inputs case
@ -156,7 +156,7 @@ class sparse_GP(GP):
beta = self.likelihood.precision beta = self.likelihood.precision
dbeta = 0.5 * self.N*self.D/beta - 0.5 * np.sum(np.square(self.likelihood.Y)) dbeta = 0.5 * self.N*self.D/beta - 0.5 * np.sum(np.square(self.likelihood.Y))
dbeta += - 0.5 * self.D * (self.psi0.sum() - np.trace(self.A)/beta*sf2) dbeta += - 0.5 * self.D * (self.psi0.sum() - np.trace(self.A)/beta*sf2)
dbeta += - 0.5 * self.D * np.sum(self.Bi*self.A)/beta dbeta += - 0.5 * self.D * trace_dot(self.Bi,self.A)/beta
dbeta += np.sum((self.C - 0.5 * mdot(self.C,self.psi2_beta_scaled,self.C) ) * self.psi1VVpsi1 )/beta dbeta += np.sum((self.C - 0.5 * mdot(self.C,self.psi2_beta_scaled,self.C) ) * self.psi1VVpsi1 )/beta
self.partial_for_likelihood = -dbeta*self.likelihood.precision**2 self.partial_for_likelihood = -dbeta*self.likelihood.precision**2

View file

@ -14,6 +14,13 @@ import types
#import scipy.lib.lapack.flapack #import scipy.lib.lapack.flapack
import scipy as sp import scipy as sp
def trace_dot(a,b):
"""
efficiently compute the trace of the matrix product of a and b
"""
assert a.shape==b.T.shape
return np.dot(a.flatten(),b.T.flatten())
def mdot(*args): def mdot(*args):
"""Multiply all the arguments using matrix product rules. """Multiply all the arguments using matrix product rules.
The output is equivalent to multiplying the arguments one by one The output is equivalent to multiplying the arguments one by one