mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
added some caching
This commit is contained in:
parent
a6d3fda234
commit
283b0745aa
2 changed files with 11 additions and 10 deletions
|
|
@ -49,13 +49,8 @@ class Linear(Kern):
|
||||||
|
|
||||||
self.variances = Param('variances', variances, Logexp())
|
self.variances = Param('variances', variances, Logexp())
|
||||||
self.add_parameter(self.variances)
|
self.add_parameter(self.variances)
|
||||||
self.variances.add_observer(self, self._on_changed)
|
|
||||||
|
|
||||||
def _on_changed(self, obj):
|
@Cache_this(limit=2)
|
||||||
#TODO: move this to base class? isnt it jst for the caching?
|
|
||||||
self._notify_observers()
|
|
||||||
|
|
||||||
#@cache_this(limit=3, reset_on_self=True)
|
|
||||||
def K(self, X, X2=None):
|
def K(self, X, X2=None):
|
||||||
if self.ARD:
|
if self.ARD:
|
||||||
if X2 is None:
|
if X2 is None:
|
||||||
|
|
@ -66,7 +61,7 @@ class Linear(Kern):
|
||||||
else:
|
else:
|
||||||
return self._dot_product(X, X2) * self.variances
|
return self._dot_product(X, X2) * self.variances
|
||||||
|
|
||||||
#@cache_this(limit=3, reset_on_self=False)
|
@Cache_this(limit=1, ignore_args(0,))
|
||||||
def _dot_product(self, X, X2=None):
|
def _dot_product(self, X, X2=None):
|
||||||
if X2 is None:
|
if X2 is None:
|
||||||
return tdot(X)
|
return tdot(X)
|
||||||
|
|
@ -113,6 +108,7 @@ class Linear(Kern):
|
||||||
def psi1(self, Z, variational_posterior):
|
def psi1(self, Z, variational_posterior):
|
||||||
return self.K(variational_posterior.mean, Z) #the variance, it does nothing
|
return self.K(variational_posterior.mean, Z) #the variance, it does nothing
|
||||||
|
|
||||||
|
@Cache_this(limit=1)
|
||||||
def psi2(self, Z, variational_posterior):
|
def psi2(self, Z, variational_posterior):
|
||||||
ZA = Z * self.variances
|
ZA = Z * self.variances
|
||||||
ZAinner = self._ZAinner(variational_posterior, Z)
|
ZAinner = self._ZAinner(variational_posterior, Z)
|
||||||
|
|
@ -126,9 +122,11 @@ class Linear(Kern):
|
||||||
if self.ARD: self.variances.gradient += tmp.sum(0)
|
if self.ARD: self.variances.gradient += tmp.sum(0)
|
||||||
else: self.variances.gradient += tmp.sum()
|
else: self.variances.gradient += tmp.sum()
|
||||||
#psi2
|
#psi2
|
||||||
tmp = dL_dpsi2[:, :, :, None] * (self._ZAinner(variational_posterior, Z)[:, :, None, :] * (2. * Z)[None, None, :, :])
|
if self.ARD:
|
||||||
if self.ARD: self.variances.gradient += tmp.sum(0).sum(0).sum(0)
|
tmp = dL_dpsi2[:, :, :, None] * (self._ZAinner(variational_posterior, Z)[:, :, None, :] * Z[None, None, :, :])
|
||||||
else: self.variances.gradient += tmp.sum()
|
self.variances.gradient += 2.*tmp.sum(0).sum(0).sum(0)
|
||||||
|
else:
|
||||||
|
self.variances.gradient += 2.*np.sum(dL_dpsi2 * self.psi2(Z, variational_posterior))/self.variances
|
||||||
|
|
||||||
def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||||
#psi1
|
#psi1
|
||||||
|
|
@ -234,9 +232,11 @@ class Linear(Kern):
|
||||||
type_converters=weave.converters.blitz,**weave_options)
|
type_converters=weave.converters.blitz,**weave_options)
|
||||||
|
|
||||||
|
|
||||||
|
@Cache_this(limit=1, ignore_args=(0,))
|
||||||
def _mu2S(self, vp):
|
def _mu2S(self, vp):
|
||||||
return np.square(vp.mean) + vp.variance
|
return np.square(vp.mean) + vp.variance
|
||||||
|
|
||||||
|
@Cache_this(limit=1)
|
||||||
def _ZAinner(self, vp, Z):
|
def _ZAinner(self, vp, Z):
|
||||||
ZA = Z*self.variances
|
ZA = Z*self.variances
|
||||||
inner = (vp.mean[:, None, :] * vp.mean[:, :, None])
|
inner = (vp.mean[:, None, :] * vp.mean[:, :, None])
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,7 @@ class RBF(Stationary):
|
||||||
# Precomputations #
|
# Precomputations #
|
||||||
#---------------------------------------#
|
#---------------------------------------#
|
||||||
|
|
||||||
|
#TODO: this function is unused, but it will be useful in the stationary class
|
||||||
def _dL_dlengthscales_via_K(self, dL_dK, X, X2):
|
def _dL_dlengthscales_via_K(self, dL_dK, X, X2):
|
||||||
"""
|
"""
|
||||||
A helper function for update_gradients_* methods
|
A helper function for update_gradients_* methods
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue