mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
[kern psi2] added flag for returning psi2 in N, not used yet, see #139
This commit is contained in:
parent
829e40b25c
commit
c128c6f948
3 changed files with 22 additions and 8 deletions
|
|
@ -6,6 +6,7 @@ import numpy as np
|
||||||
from ...core.parameterization.parameterized import Parameterized
|
from ...core.parameterization.parameterized import Parameterized
|
||||||
from kernel_slice_operations import KernCallsViaSlicerMeta
|
from kernel_slice_operations import KernCallsViaSlicerMeta
|
||||||
from ...util.caching import Cache_this
|
from ...util.caching import Cache_this
|
||||||
|
from GPy.core.parameterization.observable_array import ObsAr
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,6 +55,20 @@ class Kern(Parameterized):
|
||||||
|
|
||||||
self._sliced_X = 0
|
self._sliced_X = 0
|
||||||
self.useGPU = self._support_GPU and useGPU
|
self.useGPU = self._support_GPU and useGPU
|
||||||
|
self._return_psi2_n_flag = ObsAr(np.zeros(1)).astype(bool)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def return_psi2_n(self):
|
||||||
|
"""
|
||||||
|
Flag whether to pass back psi2 as NxMxM or MxM, by summing out N.
|
||||||
|
"""
|
||||||
|
return self._return_psi2_n_flag[0]
|
||||||
|
@return_psi2_n.setter
|
||||||
|
def return_psi2_n(self, val):
|
||||||
|
def visit(self):
|
||||||
|
if isinstance(self, Kern):
|
||||||
|
self._return_psi2_n_flag[0]=val
|
||||||
|
self.traverse(visit)
|
||||||
|
|
||||||
@Cache_this(limit=20)
|
@Cache_this(limit=20)
|
||||||
def _slice_X(self, X):
|
def _slice_X(self, X):
|
||||||
|
|
@ -162,7 +177,7 @@ class Kern(Parameterized):
|
||||||
def __mul__(self, other):
|
def __mul__(self, other):
|
||||||
""" Here we overload the '*' operator. See self.prod for more information"""
|
""" Here we overload the '*' operator. See self.prod for more information"""
|
||||||
return self.prod(other)
|
return self.prod(other)
|
||||||
|
|
||||||
def __imul__(self, other):
|
def __imul__(self, other):
|
||||||
""" Here we overload the '*' operator. See self.prod for more information"""
|
""" Here we overload the '*' operator. See self.prod for more information"""
|
||||||
return self.prod(other)
|
return self.prod(other)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import sslinear_psi_comp
|
||||||
import linear_psi_comp
|
import linear_psi_comp
|
||||||
|
|
||||||
class PSICOMP_RBF(Pickleable):
|
class PSICOMP_RBF(Pickleable):
|
||||||
|
|
||||||
@Cache_this(limit=2, ignore_args=(0,))
|
@Cache_this(limit=2, ignore_args=(0,))
|
||||||
def psicomputations(self, variance, lengthscale, Z, variational_posterior):
|
def psicomputations(self, variance, lengthscale, Z, variational_posterior):
|
||||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||||
|
|
@ -19,7 +18,7 @@ class PSICOMP_RBF(Pickleable):
|
||||||
return ssrbf_psi_comp.psicomputations(variance, lengthscale, Z, variational_posterior)
|
return ssrbf_psi_comp.psicomputations(variance, lengthscale, Z, variational_posterior)
|
||||||
else:
|
else:
|
||||||
raise ValueError, "unknown distriubtion received for psi-statistics"
|
raise ValueError, "unknown distriubtion received for psi-statistics"
|
||||||
|
|
||||||
@Cache_this(limit=2, ignore_args=(0,1,2,3))
|
@Cache_this(limit=2, ignore_args=(0,1,2,3))
|
||||||
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
|
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
|
||||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||||
|
|
@ -28,10 +27,10 @@ class PSICOMP_RBF(Pickleable):
|
||||||
return ssrbf_psi_comp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior)
|
return ssrbf_psi_comp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior)
|
||||||
else:
|
else:
|
||||||
raise ValueError, "unknown distriubtion received for psi-statistics"
|
raise ValueError, "unknown distriubtion received for psi-statistics"
|
||||||
|
|
||||||
def _setup_observers(self):
|
def _setup_observers(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class PSICOMP_Linear(Pickleable):
|
class PSICOMP_Linear(Pickleable):
|
||||||
|
|
||||||
@Cache_this(limit=2, ignore_args=(0,))
|
@Cache_this(limit=2, ignore_args=(0,))
|
||||||
|
|
@ -42,7 +41,7 @@ class PSICOMP_Linear(Pickleable):
|
||||||
return sslinear_psi_comp.psicomputations(variance, Z, variational_posterior)
|
return sslinear_psi_comp.psicomputations(variance, Z, variational_posterior)
|
||||||
else:
|
else:
|
||||||
raise ValueError, "unknown distriubtion received for psi-statistics"
|
raise ValueError, "unknown distriubtion received for psi-statistics"
|
||||||
|
|
||||||
@Cache_this(limit=2, ignore_args=(0,1,2,3))
|
@Cache_this(limit=2, ignore_args=(0,1,2,3))
|
||||||
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior):
|
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior):
|
||||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||||
|
|
@ -51,6 +50,6 @@ class PSICOMP_Linear(Pickleable):
|
||||||
return sslinear_psi_comp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior)
|
return sslinear_psi_comp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior)
|
||||||
else:
|
else:
|
||||||
raise ValueError, "unknown distriubtion received for psi-statistics"
|
raise ValueError, "unknown distriubtion received for psi-statistics"
|
||||||
|
|
||||||
def _setup_observers(self):
|
def _setup_observers(self):
|
||||||
pass
|
pass
|
||||||
|
|
@ -139,7 +139,7 @@ def _psi2compDer(dL_dpsi2, variance, lengthscale, Z, mu, S):
|
||||||
denom2 = np.square(denom)
|
denom2 = np.square(denom)
|
||||||
|
|
||||||
_psi2 = _psi2computations(variance, lengthscale, Z, mu, S) # NxMxM
|
_psi2 = _psi2computations(variance, lengthscale, Z, mu, S) # NxMxM
|
||||||
Lpsi2 = dL_dpsi2[None,:,:]*_psi2
|
Lpsi2 = dL_dpsi2*_psi2 # dL_dpsi2 is MxM, using broadcast to multiply N out
|
||||||
Lpsi2sum = np.einsum('nmo->n',Lpsi2) #N
|
Lpsi2sum = np.einsum('nmo->n',Lpsi2) #N
|
||||||
Lpsi2Z = np.einsum('nmo,oq->nq',Lpsi2,Z) #NxQ
|
Lpsi2Z = np.einsum('nmo,oq->nq',Lpsi2,Z) #NxQ
|
||||||
Lpsi2Z2 = np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ
|
Lpsi2Z2 = np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue