mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
fixed merge conflict on BPGLVM
This commit is contained in:
commit
a42d84274d
10 changed files with 367 additions and 67 deletions
|
|
@ -52,7 +52,7 @@ class kern(parameterised):
|
|||
parameterised.__init__(self)
|
||||
|
||||
|
||||
def plot_ARD(self):
|
||||
def plot_ARD(self, ax=pb.gca()):
|
||||
"""
|
||||
If an ARD kernel is present, it bar-plots the ARD parameters
|
||||
|
||||
|
|
@ -60,17 +60,17 @@ class kern(parameterised):
|
|||
"""
|
||||
for p in self.parts:
|
||||
if hasattr(p, 'ARD') and p.ARD:
|
||||
pb.figure()
|
||||
pb.title('ARD parameters, %s kernel' % p.name)
|
||||
ax.set_title('ARD parameters, %s kernel' % p.name)
|
||||
|
||||
if p.name == 'linear':
|
||||
ard_params = p.variances
|
||||
else:
|
||||
ard_params = 1./p.lengthscale
|
||||
|
||||
pb.bar(np.arange(len(ard_params))-0.4, ard_params)
|
||||
|
||||
|
||||
ax.bar(np.arange(len(ard_params)) - 0.4, ard_params)
|
||||
ax.set_xticks(np.arange(len(ard_params)),
|
||||
["${}$".format(i + 1) for i in range(len(ard_params))])
|
||||
return ax
|
||||
|
||||
def _transform_gradients(self,g):
|
||||
x = self._get_params()
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ class rbf(kernpart):
|
|||
"""Think N,M,M,Q """
|
||||
self._psi_computations(Z,mu,S)
|
||||
tmp = self._psi2[:,:,:,None]/self.lengthscale2/self._psi2_denom
|
||||
target_mu += (dL_dpsi2[:,:,:,None]*-tmp*2.*self._psi2_mudist).sum(1).sum(1)
|
||||
target_mu += -2.*(dL_dpsi2[:,:,:,None]*tmp*self._psi2_mudist).sum(1).sum(1)
|
||||
target_S += (dL_dpsi2[:,:,:,None]*tmp*(2.*self._psi2_mudist_sq-1)).sum(1).sum(1)
|
||||
|
||||
|
||||
|
|
@ -207,7 +207,6 @@ class rbf(kernpart):
|
|||
if not (np.all(Z==self._Z) and np.all(mu==self._mu) and np.all(S==self._S)):
|
||||
#something's changed. recompute EVERYTHING
|
||||
|
||||
#TODO: make more efficient for large Q (using NDL's dot product trick)
|
||||
#psi1
|
||||
self._psi1_denom = S[:,None,:]/self.lengthscale2 + 1.
|
||||
self._psi1_dist = Z[None,:,:]-mu[:,None,:]
|
||||
|
|
@ -249,7 +248,7 @@ class rbf(kernpart):
|
|||
_psi2_denom = self._psi2_denom.squeeze()
|
||||
code = """
|
||||
double tmp;
|
||||
|
||||
|
||||
#pragma omp parallel for private(tmp)
|
||||
for (int n=0; n<N; n++){
|
||||
for (int m=0; m<M; m++){
|
||||
|
|
@ -286,8 +285,8 @@ class rbf(kernpart):
|
|||
#include <omp.h>
|
||||
#include <math.h>
|
||||
"""
|
||||
weave.inline(code, support_code=support_code, libraries=['gomp'],
|
||||
weave.inline(code, support_code=support_code, libraries=['gomp'],
|
||||
arg_names=['N','M','Q','mu','Zhat','mudist_sq','mudist','lengthscale2','_psi2_denom','psi2_Zdist_sq','psi2_exponent','half_log_psi2_denom','psi2','variance_sq'],
|
||||
type_converters=weave.converters.blitz,**weave_options)
|
||||
|
||||
|
||||
return mudist,mudist_sq, psi2_exponent, psi2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue