mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
Merge branch 'devel' of github.com:SheffieldML/GPy into devel
This commit is contained in:
commit
f15df22f43
4 changed files with 19 additions and 42 deletions
|
|
@ -100,6 +100,9 @@ class linear(Kernpart):
|
|||
def dK_dX(self, dL_dK, X, X2, target):
|
||||
target += (((X2[:, None, :] * self.variances)) * dL_dK[:, :, None]).sum(0)
|
||||
|
||||
def dKdiag_dX(self,dL_dKdiag,X,target):
|
||||
target += 2.*self.variances*dL_dKdiag[:,None]*X
|
||||
|
||||
#---------------------------------------#
|
||||
# PSI statistics #
|
||||
#---------------------------------------#
|
||||
|
|
|
|||
|
|
@ -96,13 +96,13 @@ class rbf(Kernpart):
|
|||
var_len3 = self.variance / np.power(self.lengthscale, 3)
|
||||
if X2 is None:
|
||||
# save computation for the symmetrical case
|
||||
dvardLdK += dvardLdK.T
|
||||
dvardLdK = dvardLdK + dvardLdK.T
|
||||
code = """
|
||||
int q,i,j;
|
||||
double tmp;
|
||||
for(q=0; q<input_dim; q++){
|
||||
tmp = 0;
|
||||
for(i=0; i<N; i++){
|
||||
for(i=0; i<num_data; i++){
|
||||
for(j=0; j<i; j++){
|
||||
tmp += (X(i,q)-X(j,q))*(X(i,q)-X(j,q))*dvardLdK(i,j);
|
||||
}
|
||||
|
|
@ -110,14 +110,15 @@ class rbf(Kernpart):
|
|||
target(q+1) += var_len3(q)*tmp;
|
||||
}
|
||||
"""
|
||||
N, num_inducing, input_dim = X.shape[0], X.shape[0], self.input_dim
|
||||
num_data, num_inducing, input_dim = X.shape[0], X.shape[0], self.input_dim
|
||||
weave.inline(code, arg_names=['num_data','num_inducing','input_dim','X','X2','target','dvardLdK','var_len3'], type_converters=weave.converters.blitz, **self.weave_options)
|
||||
else:
|
||||
code = """
|
||||
int q,i,j;
|
||||
double tmp;
|
||||
for(q=0; q<input_dim; q++){
|
||||
tmp = 0;
|
||||
for(i=0; i<N; i++){
|
||||
for(i=0; i<num_data; i++){
|
||||
for(j=0; j<num_inducing; j++){
|
||||
tmp += (X(i,q)-X2(j,q))*(X(i,q)-X2(j,q))*dvardLdK(i,j);
|
||||
}
|
||||
|
|
@ -125,10 +126,9 @@ class rbf(Kernpart):
|
|||
target(q+1) += var_len3(q)*tmp;
|
||||
}
|
||||
"""
|
||||
N, num_inducing, input_dim = X.shape[0], X2.shape[0], self.input_dim
|
||||
# [np.add(target[1+q:2+q],var_len3[q]*np.sum(dvardLdK*np.square(X[:,q][:,None]-X2[:,q][None,:])),target[1+q:2+q]) for q in range(self.input_dim)]
|
||||
weave.inline(code, arg_names=['N','num_inducing','input_dim','X','X2','target','dvardLdK','var_len3'],
|
||||
type_converters=weave.converters.blitz, **self.weave_options)
|
||||
num_data, num_inducing, input_dim = X.shape[0], X2.shape[0], self.input_dim
|
||||
#[np.add(target[1+q:2+q],var_len3[q]*np.sum(dvardLdK*np.square(X[:,q][:,None]-X2[:,q][None,:])),target[1+q:2+q]) for q in range(self.input_dim)]
|
||||
weave.inline(code, arg_names=['num_data','num_inducing','input_dim','X','X2','target','dvardLdK','var_len3'], type_converters=weave.converters.blitz, **self.weave_options)
|
||||
else:
|
||||
target[1] += (self.variance / self.lengthscale) * np.sum(self._K_dvar * self._K_dist2 * dL_dK)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,41 +19,15 @@ class LinkFunction(object):
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
class Identity(LinkFunction):
|
||||
def transf(self,mu):
|
||||
return mu
|
||||
|
||||
def inv_transf(self,f):
|
||||
return f
|
||||
|
||||
def log_inv_transf(self,f):
|
||||
return np.log(f)
|
||||
|
||||
class Log(LinkFunction):
|
||||
|
||||
def transf(self,mu):
|
||||
return np.log(mu)
|
||||
|
||||
def inv_transf(self,f):
|
||||
return np.exp(f)
|
||||
|
||||
def log_inv_transf(self,f):
|
||||
return f
|
||||
|
||||
class Log_ex_1(LinkFunction):
|
||||
def transf(self,mu):
|
||||
return np.log(np.exp(mu) - 1)
|
||||
|
||||
def inv_transf(self,f):
|
||||
return np.log(np.exp(f)+1)
|
||||
|
||||
def log_inv_tranf(self,f):
|
||||
return np.log(np.log(np.exp(f)+1))
|
||||
|
||||
class Probit(LinkFunction):
|
||||
"""
|
||||
Probit link function: Squashes a likelihood between 0 and 1
|
||||
"""
|
||||
def transf(self,mu):
|
||||
pass
|
||||
|
||||
def inv_transf(self,f):
|
||||
return std_norm_cdf(f)
|
||||
pass
|
||||
|
||||
def log_inv_transf(self,f):
|
||||
return np.log(std_norm_cdf(f))
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -11,6 +11,6 @@ import misc
|
|||
import warping_functions
|
||||
import datasets
|
||||
import mocap
|
||||
import visualize
|
||||
#import visualize
|
||||
import decorators
|
||||
import classification
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue