mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
added some caching
This commit is contained in:
parent
a6d3fda234
commit
283b0745aa
2 changed files with 11 additions and 10 deletions
|
|
@ -49,13 +49,8 @@ class Linear(Kern):
|
|||
|
||||
self.variances = Param('variances', variances, Logexp())
|
||||
self.add_parameter(self.variances)
|
||||
self.variances.add_observer(self, self._on_changed)
|
||||
|
||||
def _on_changed(self, obj):
|
||||
#TODO: move this to base class? isnt it jst for the caching?
|
||||
self._notify_observers()
|
||||
|
||||
#@cache_this(limit=3, reset_on_self=True)
|
||||
@Cache_this(limit=2)
|
||||
def K(self, X, X2=None):
|
||||
if self.ARD:
|
||||
if X2 is None:
|
||||
|
|
@ -66,7 +61,7 @@ class Linear(Kern):
|
|||
else:
|
||||
return self._dot_product(X, X2) * self.variances
|
||||
|
||||
#@cache_this(limit=3, reset_on_self=False)
|
||||
@Cache_this(limit=1, ignore_args(0,))
|
||||
def _dot_product(self, X, X2=None):
|
||||
if X2 is None:
|
||||
return tdot(X)
|
||||
|
|
@ -113,6 +108,7 @@ class Linear(Kern):
|
|||
def psi1(self, Z, variational_posterior):
|
||||
return self.K(variational_posterior.mean, Z) #the variance, it does nothing
|
||||
|
||||
@Cache_this(limit=1)
|
||||
def psi2(self, Z, variational_posterior):
|
||||
ZA = Z * self.variances
|
||||
ZAinner = self._ZAinner(variational_posterior, Z)
|
||||
|
|
@ -126,9 +122,11 @@ class Linear(Kern):
|
|||
if self.ARD: self.variances.gradient += tmp.sum(0)
|
||||
else: self.variances.gradient += tmp.sum()
|
||||
#psi2
|
||||
tmp = dL_dpsi2[:, :, :, None] * (self._ZAinner(variational_posterior, Z)[:, :, None, :] * (2. * Z)[None, None, :, :])
|
||||
if self.ARD: self.variances.gradient += tmp.sum(0).sum(0).sum(0)
|
||||
else: self.variances.gradient += tmp.sum()
|
||||
if self.ARD:
|
||||
tmp = dL_dpsi2[:, :, :, None] * (self._ZAinner(variational_posterior, Z)[:, :, None, :] * Z[None, None, :, :])
|
||||
self.variances.gradient += 2.*tmp.sum(0).sum(0).sum(0)
|
||||
else:
|
||||
self.variances.gradient += 2.*np.sum(dL_dpsi2 * self.psi2(Z, variational_posterior))/self.variances
|
||||
|
||||
def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
#psi1
|
||||
|
|
@ -234,9 +232,11 @@ class Linear(Kern):
|
|||
type_converters=weave.converters.blitz,**weave_options)
|
||||
|
||||
|
||||
@Cache_this(limit=1, ignore_args=(0,))
|
||||
def _mu2S(self, vp):
|
||||
return np.square(vp.mean) + vp.variance
|
||||
|
||||
@Cache_this(limit=1)
|
||||
def _ZAinner(self, vp, Z):
|
||||
ZA = Z*self.variances
|
||||
inner = (vp.mean[:, None, :] * vp.mean[:, :, None])
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ class RBF(Stationary):
|
|||
# Precomputations #
|
||||
#---------------------------------------#
|
||||
|
||||
#TODO: this function is unused, but it will be useful in the stationary class
|
||||
def _dL_dlengthscales_via_K(self, dL_dK, X, X2):
|
||||
"""
|
||||
A helper function for update_gradients_* methods
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue