mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
finish the debug of sparsegp_mpi
This commit is contained in:
parent
31f11eda40
commit
d7eee6aa00
2 changed files with 10 additions and 20 deletions
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sparse_gp import SparseGP
|
from sparse_gp import SparseGP
|
||||||
|
from numpy.linalg.linalg import LinAlgError
|
||||||
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
|
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -83,11 +84,7 @@ class SparseGP_MPI(SparseGP):
|
||||||
if self.mpi_comm != None:
|
if self.mpi_comm != None:
|
||||||
if self._IN_OPTIMIZATION_ and self.mpi_comm.rank==0:
|
if self._IN_OPTIMIZATION_ and self.mpi_comm.rank==0:
|
||||||
self.mpi_comm.Bcast(np.int32(1),root=0)
|
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||||
self.mpi_comm.Bcast(p, root=0)
|
self.mpi_comm.Bcast(p, root=0)
|
||||||
|
|
||||||
from ..util.debug import checkFinite
|
|
||||||
checkFinite(p, 'optimizer_array')
|
|
||||||
|
|
||||||
SparseGP.optimizer_array.fset(self,p)
|
SparseGP.optimizer_array.fset(self,p)
|
||||||
|
|
||||||
def optimize(self, optimizer=None, start=None, **kwargs):
|
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||||
|
|
@ -103,7 +100,13 @@ class SparseGP_MPI(SparseGP):
|
||||||
while True:
|
while True:
|
||||||
self.mpi_comm.Bcast(flag,root=0)
|
self.mpi_comm.Bcast(flag,root=0)
|
||||||
if flag==1:
|
if flag==1:
|
||||||
self.optimizer_array = x
|
try:
|
||||||
|
self.optimizer_array = x
|
||||||
|
self._fail_count = 0
|
||||||
|
except (LinAlgError, ZeroDivisionError, ValueError):
|
||||||
|
if self._fail_count >= self._allowed_failures:
|
||||||
|
raise
|
||||||
|
self._fail_count += 1
|
||||||
elif flag==-1:
|
elif flag==-1:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -167,18 +167,12 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
# Compute Common Components
|
# Compute Common Components
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
||||||
from ...util.debug import checkFullRank
|
|
||||||
|
|
||||||
Kmm = kern.K(Z).copy()
|
Kmm = kern.K(Z).copy()
|
||||||
diag.add(Kmm, self.const_jitter)
|
diag.add(Kmm, self.const_jitter)
|
||||||
r1 = checkFullRank(Kmm,name='Kmm')
|
|
||||||
KmmInv,Lm,LmInv,_ = pdinv(Kmm)
|
KmmInv,Lm,LmInv,_ = pdinv(Kmm)
|
||||||
|
|
||||||
LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
|
LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
|
||||||
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
|
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
|
||||||
r2 = checkFullRank(Lambda,name='Lambda')
|
|
||||||
# if (not r1) or (not r2):
|
|
||||||
# raise
|
|
||||||
LInv,LL,LLInv,logdet_L = pdinv(Lambda)
|
LInv,LL,LLInv,logdet_L = pdinv(Lambda)
|
||||||
b = LLInv.dot(LmInv.dot(psi1Y_full.T))
|
b = LLInv.dot(LmInv.dot(psi1Y_full.T))
|
||||||
bbt = np.square(b).sum()
|
bbt = np.square(b).sum()
|
||||||
|
|
@ -203,7 +197,6 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
# Compute dL_dKmm
|
# Compute dL_dKmm
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
||||||
# dL_dKmm = -(output_dim*np.einsum('md,od->mo',KmmInvPsi2LLInvT,KmmInvPsi2LLInvT) + vvt)/2.
|
|
||||||
dL_dKmm = dL_dpsi2R - output_dim*KmmInv.dot(psi2_full).dot(KmmInv)/2.
|
dL_dKmm = dL_dpsi2R - output_dim*KmmInv.dot(psi2_full).dot(KmmInv)/2.
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
|
|
@ -336,13 +329,7 @@ def update_gradients(model, mpi_comm=None):
|
||||||
Y = model.Y_local
|
Y = model.Y_local
|
||||||
X = model.X[model.N_range[0]:model.N_range[1]]
|
X = model.X[model.N_range[0]:model.N_range[1]]
|
||||||
|
|
||||||
try:
|
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y)
|
||||||
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y)
|
|
||||||
except Exception:
|
|
||||||
if model.mpi_comm is None or model.mpi_comm.rank==0:
|
|
||||||
import time
|
|
||||||
model.pickle('model_'+str(int(time.time()))+'.pickle')
|
|
||||||
raise
|
|
||||||
|
|
||||||
het_noise = model.likelihood.variance.size > 1
|
het_noise = model.likelihood.variance.size > 1
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue