fixed merge conflict on BPGLVM

This commit is contained in:
Nicolo Fusi 2013-04-12 13:35:00 +01:00
commit a42d84274d
10 changed files with 367 additions and 67 deletions

View file

@ -52,7 +52,7 @@ class kern(parameterised):
parameterised.__init__(self)
def plot_ARD(self):
def plot_ARD(self, ax=pb.gca()):
"""
If an ARD kernel is present, it bar-plots the ARD parameters
@ -60,17 +60,17 @@ class kern(parameterised):
"""
for p in self.parts:
if hasattr(p, 'ARD') and p.ARD:
pb.figure()
pb.title('ARD parameters, %s kernel' % p.name)
ax.set_title('ARD parameters, %s kernel' % p.name)
if p.name == 'linear':
ard_params = p.variances
else:
ard_params = 1./p.lengthscale
pb.bar(np.arange(len(ard_params))-0.4, ard_params)
ax.bar(np.arange(len(ard_params)) - 0.4, ard_params)
ax.set_xticks(np.arange(len(ard_params)),
["${}$".format(i + 1) for i in range(len(ard_params))])
return ax
def _transform_gradients(self,g):
x = self._get_params()

View file

@ -173,7 +173,7 @@ class rbf(kernpart):
"""Think N,M,M,Q """
self._psi_computations(Z,mu,S)
tmp = self._psi2[:,:,:,None]/self.lengthscale2/self._psi2_denom
target_mu += (dL_dpsi2[:,:,:,None]*-tmp*2.*self._psi2_mudist).sum(1).sum(1)
target_mu += -2.*(dL_dpsi2[:,:,:,None]*tmp*self._psi2_mudist).sum(1).sum(1)
target_S += (dL_dpsi2[:,:,:,None]*tmp*(2.*self._psi2_mudist_sq-1)).sum(1).sum(1)
@ -207,7 +207,6 @@ class rbf(kernpart):
if not (np.all(Z==self._Z) and np.all(mu==self._mu) and np.all(S==self._S)):
#something's changed. recompute EVERYTHING
#TODO: make more efficient for large Q (using NDL's dot product trick)
#psi1
self._psi1_denom = S[:,None,:]/self.lengthscale2 + 1.
self._psi1_dist = Z[None,:,:]-mu[:,None,:]
@ -249,7 +248,7 @@ class rbf(kernpart):
_psi2_denom = self._psi2_denom.squeeze()
code = """
double tmp;
#pragma omp parallel for private(tmp)
for (int n=0; n<N; n++){
for (int m=0; m<M; m++){
@ -286,8 +285,8 @@ class rbf(kernpart):
#include <omp.h>
#include <math.h>
"""
weave.inline(code, support_code=support_code, libraries=['gomp'],
weave.inline(code, support_code=support_code, libraries=['gomp'],
arg_names=['N','M','Q','mu','Zhat','mudist_sq','mudist','lengthscale2','_psi2_denom','psi2_Zdist_sq','psi2_exponent','half_log_psi2_denom','psi2','variance_sq'],
type_converters=weave.converters.blitz,**weave_options)
return mudist,mudist_sq, psi2_exponent, psi2