mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
finish the debug of sparsegp_mpi
This commit is contained in:
parent
31f11eda40
commit
d7eee6aa00
2 changed files with 10 additions and 20 deletions
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import numpy as np
|
||||
from sparse_gp import SparseGP
|
||||
from numpy.linalg.linalg import LinAlgError
|
||||
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
|
||||
|
||||
import logging
|
||||
|
|
@ -83,11 +84,7 @@ class SparseGP_MPI(SparseGP):
|
|||
if self.mpi_comm != None:
|
||||
if self._IN_OPTIMIZATION_ and self.mpi_comm.rank==0:
|
||||
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||
self.mpi_comm.Bcast(p, root=0)
|
||||
|
||||
from ..util.debug import checkFinite
|
||||
checkFinite(p, 'optimizer_array')
|
||||
|
||||
self.mpi_comm.Bcast(p, root=0)
|
||||
SparseGP.optimizer_array.fset(self,p)
|
||||
|
||||
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||
|
|
@ -103,7 +100,13 @@ class SparseGP_MPI(SparseGP):
|
|||
while True:
|
||||
self.mpi_comm.Bcast(flag,root=0)
|
||||
if flag==1:
|
||||
self.optimizer_array = x
|
||||
try:
|
||||
self.optimizer_array = x
|
||||
self._fail_count = 0
|
||||
except (LinAlgError, ZeroDivisionError, ValueError):
|
||||
if self._fail_count >= self._allowed_failures:
|
||||
raise
|
||||
self._fail_count += 1
|
||||
elif flag==-1:
|
||||
break
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue