weaved some rbf code

This commit is contained in:
James Hensman 2013-04-10 15:50:31 +01:00
parent eb1d8f211f
commit 4aca883df3
2 changed files with 54 additions and 5 deletions

View file

@ -229,13 +229,18 @@ class model(parameterised):
Ensure that any variables which should clearly be positive have been constrained somehow. Ensure that any variables which should clearly be positive have been constrained somehow.
""" """
positive_strings = ['variance','lengthscale', 'precision'] positive_strings = ['variance','lengthscale', 'precision']
param_names = self._get_param_names()
currently_constrained = self.all_constrained_indices()
to_make_positive = []
for s in positive_strings: for s in positive_strings:
for i in self.grep_param_names(s): for i in self.grep_param_names(s):
if not (i in self.all_constrained_indices()): if not (i in currently_constrained):
name = self._get_param_names()[i] to_make_positive.append(param_names[i])
self.constrain_positive(name)
if warn: if warn:
print "Warning! constraining %s postive"%name print "Warning! constraining %s postive"%name
if len(to_make_positive):
self.constrain_positive('('+'|'.join(to_make_positive)+')')
def objective_function(self, x): def objective_function(self, x):

View file

@ -5,6 +5,7 @@
from kernpart import kernpart from kernpart import kernpart
import numpy as np import numpy as np
import hashlib import hashlib
from scipy import weave
class rbf(kernpart): class rbf(kernpart):
""" """
@ -217,8 +218,51 @@ class rbf(kernpart):
#psi2 #psi2
self._psi2_denom = 2.*S[:,None,None,:]/self.lengthscale2+1. # N,M,M,Q self._psi2_denom = 2.*S[:,None,None,:]/self.lengthscale2+1. # N,M,M,Q
self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q
self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom) self._psi2_mudist_sq, self._psi2_exponent, _ = self.weave_stuff()
self._psi2_exponent = np.sum(-self._psi2_Zdist_sq/4. -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M #self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom)
#self._psi2_exponent = np.sum(-self._psi2_Zdist_sq/4. -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
self._psi2 = np.square(self.variance)*np.exp(self._psi2_exponent) # N,M,M self._psi2 = np.square(self.variance)*np.exp(self._psi2_exponent) # N,M,M
#store matrices for caching
self._Z, self._mu, self._S = Z, mu,S self._Z, self._mu, self._S = Z, mu,S
def weave_psi2(self):
weave_options = {'extra_compile_args': ['-O3']}
N,M,M,Q = self._psi2_mudist.shape
mudist = self._psi2_mudist
psi2_Zdist_sq = self._psi2_Zdist_sq
mudist_sq = np.empty((N,M,M,Q))
psi2_exponent = np.zeros((N,M,M))
psi2 = np.empty((N,M,M))
half_log_psi2_denom = 0.5*np.log(self._psi2_denom).squeeze()
variance_sq = float(np.square(self.variance))
if self.ARD:
lengthscale2 = self.lengthscale2
else:
lengthscale2 = np.ones(Q)*self.lengthscale2
_psi2_denom = self._psi2_denom.squeeze()
code = """
double tmp;
for (int n=0; n<N; n++){
for (int m=0; m<M; m++){
for (int mm=0; mm<(m+1); mm++){
for (int q=0; q<Q; q++){
tmp = mudist(n,m,mm,q)*mudist(n,m,mm,q)/(lengthscale2(q)*_psi2_denom(n,q));
mudist_sq(n,m,mm,q) = tmp;
mudist_sq(n,mm,m,q) = tmp;
tmp = -psi2_Zdist_sq(m,mm,q)/4.0 - tmp - half_log_psi2_denom(n,q);
psi2_exponent(n,mm,m) += tmp;
if (m !=mm){
psi2_exponent(n,m,mm) += tmp;
}
//tmp = variance_sq*exp(psi2_exponent(n,m,mm));
//psi2(n,m,mm) = tmp;
//psi2(n,mm,m) = tmp;
}
}
}
}
"""
weave.inline(code,support_code='#include "math.h"',arg_names=['N','M','Q','mudist_sq','mudist','lengthscale2','_psi2_denom','psi2_Zdist_sq','psi2_exponent','half_log_psi2_denom','psi2','variance_sq'],type_converters=weave.converters.blitz,**weave_options)
return mudist_sq, psi2_exponent, psi2