mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
fix add kernel and VarDTC_minibatch speed tuning
This commit is contained in:
parent
129985998c
commit
216de32c0c
3 changed files with 18 additions and 9 deletions
|
|
@ -94,6 +94,10 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
|
|
||||||
for n_start in xrange(0,num_data,self.batchsize):
|
for n_start in xrange(0,num_data,self.batchsize):
|
||||||
n_end = min(self.batchsize+n_start, num_data)
|
n_end = min(self.batchsize+n_start, num_data)
|
||||||
|
if (n_end-n_start)==num_data:
|
||||||
|
Y_slice = Y
|
||||||
|
X_slice = X
|
||||||
|
else:
|
||||||
Y_slice = Y[n_start:n_end]
|
Y_slice = Y[n_start:n_end]
|
||||||
X_slice = X[n_start:n_end]
|
X_slice = X[n_start:n_end]
|
||||||
|
|
||||||
|
|
@ -347,7 +351,9 @@ def update_gradients(model, mpi_comm=None):
|
||||||
while not isEnd:
|
while not isEnd:
|
||||||
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
|
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
|
||||||
if isinstance(model.X, VariationalPosterior):
|
if isinstance(model.X, VariationalPosterior):
|
||||||
if mpi_comm ==None:
|
if (n_range[1]-n_range[0])==X.shape[0]:
|
||||||
|
X_slice = X
|
||||||
|
elif mpi_comm ==None:
|
||||||
X_slice = model.X[n_range[0]:n_range[1]]
|
X_slice = model.X[n_range[0]:n_range[1]]
|
||||||
else:
|
else:
|
||||||
X_slice = model.X[model.N_range[0]+n_range[0]:model.N_range[0]+n_range[1]]
|
X_slice = model.X[model.N_range[0]+n_range[0]:model.N_range[0]+n_range[1]]
|
||||||
|
|
|
||||||
|
|
@ -64,12 +64,15 @@ class Add(CombinationKernel):
|
||||||
[target.__iadd__(p.gradients_X_diag(dL_dKdiag, X)) for p in self.parts]
|
[target.__iadd__(p.gradients_X_diag(dL_dKdiag, X)) for p in self.parts]
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
@Cache_this(limit=2, force_kwargs=['which_parts'])
|
||||||
def psi0(self, Z, variational_posterior):
|
def psi0(self, Z, variational_posterior):
|
||||||
return reduce(np.add, (p.psi0(Z, variational_posterior) for p in self.parts))
|
return reduce(np.add, (p.psi0(Z, variational_posterior) for p in self.parts))
|
||||||
|
|
||||||
|
@Cache_this(limit=2, force_kwargs=['which_parts'])
|
||||||
def psi1(self, Z, variational_posterior):
|
def psi1(self, Z, variational_posterior):
|
||||||
return reduce(np.add, (p.psi1(Z, variational_posterior) for p in self.parts))
|
return reduce(np.add, (p.psi1(Z, variational_posterior) for p in self.parts))
|
||||||
|
|
||||||
|
@Cache_this(limit=2, force_kwargs=['which_parts'])
|
||||||
def psi2(self, Z, variational_posterior):
|
def psi2(self, Z, variational_posterior):
|
||||||
psi2 = reduce(np.add, (p.psi2(Z, variational_posterior) for p in self.parts))
|
psi2 = reduce(np.add, (p.psi2(Z, variational_posterior) for p in self.parts))
|
||||||
#return psi2
|
#return psi2
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import sslinear_psi_comp
|
||||||
|
|
||||||
class PSICOMP_RBF(Pickleable):
|
class PSICOMP_RBF(Pickleable):
|
||||||
|
|
||||||
@Cache_this(limit=1, ignore_args=(0,))
|
@Cache_this(limit=2, ignore_args=(0,))
|
||||||
def psicomputations(self, variance, lengthscale, Z, variational_posterior):
|
def psicomputations(self, variance, lengthscale, Z, variational_posterior):
|
||||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||||
return rbf_psi_comp.psicomputations(variance, lengthscale, Z, variational_posterior)
|
return rbf_psi_comp.psicomputations(variance, lengthscale, Z, variational_posterior)
|
||||||
|
|
@ -19,7 +19,7 @@ class PSICOMP_RBF(Pickleable):
|
||||||
else:
|
else:
|
||||||
raise ValueError, "unknown distriubtion received for psi-statistics"
|
raise ValueError, "unknown distriubtion received for psi-statistics"
|
||||||
|
|
||||||
@Cache_this(limit=1, ignore_args=(0,1,2,3))
|
@Cache_this(limit=2, ignore_args=(0,1,2,3))
|
||||||
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
|
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
|
||||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||||
return rbf_psi_comp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior)
|
return rbf_psi_comp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior)
|
||||||
|
|
@ -30,7 +30,7 @@ class PSICOMP_RBF(Pickleable):
|
||||||
|
|
||||||
class PSICOMP_Linear(Pickleable):
|
class PSICOMP_Linear(Pickleable):
|
||||||
|
|
||||||
@Cache_this(limit=1, ignore_args=(0,))
|
@Cache_this(limit=2, ignore_args=(0,))
|
||||||
def psicomputations(self, variance, Z, variational_posterior):
|
def psicomputations(self, variance, Z, variational_posterior):
|
||||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
@ -39,7 +39,7 @@ class PSICOMP_Linear(Pickleable):
|
||||||
else:
|
else:
|
||||||
raise ValueError, "unknown distriubtion received for psi-statistics"
|
raise ValueError, "unknown distriubtion received for psi-statistics"
|
||||||
|
|
||||||
@Cache_this(limit=1, ignore_args=(0,1,2,3))
|
@Cache_this(limit=2, ignore_args=(0,1,2,3))
|
||||||
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior):
|
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior):
|
||||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue