mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
added robust pickling, switches to old behaviour, if get/setstate not implemented
This commit is contained in:
parent
05e8e75c58
commit
f5effb8cb6
27 changed files with 392 additions and 283 deletions
|
|
@ -6,18 +6,18 @@ from .. import likelihoods
|
|||
from ..inference import optimization
|
||||
from ..util.linalg import jitchol
|
||||
from GPy.util.misc import opt_wrapper
|
||||
from parameterised import Parameterised
|
||||
from parameterized import Parameterized
|
||||
import multiprocessing as mp
|
||||
import numpy as np
|
||||
from GPy.core.domains import POSITIVE, REAL
|
||||
from numpy.linalg.linalg import LinAlgError
|
||||
# import numdifftools as ndt
|
||||
|
||||
class Model(Parameterised):
|
||||
class Model(Parameterized):
|
||||
_fail_count = 0 # Count of failed optimization steps (see objective)
|
||||
_allowed_failures = 10 # number of allowed failures
|
||||
def __init__(self):
|
||||
Parameterised.__init__(self)
|
||||
Parameterized.__init__(self)
|
||||
self.priors = None
|
||||
self.optimization_runs = []
|
||||
self.sampling_runs = []
|
||||
|
|
@ -32,24 +32,27 @@ class Model(Parameterised):
|
|||
def _log_likelihood_gradients(self):
|
||||
raise NotImplementedError, "this needs to be implemented to use the Model class"
|
||||
|
||||
def __getstate__(self):
|
||||
def getstate(self):
|
||||
"""
|
||||
Get the current state of the class,
|
||||
here just all the indices, rest can get recomputed
|
||||
Get the current state of the class.
|
||||
|
||||
Inherited from Parameterized, so add those parameters to the state
|
||||
"""
|
||||
return Parameterised.__getstate__(self) + \
|
||||
return Parameterized.getstate(self) + \
|
||||
[self.priors, self.optimization_runs,
|
||||
self.sampling_runs, self.preferred_optimizer]
|
||||
|
||||
def __setstate__(self, state):
|
||||
def setstate(self, state):
|
||||
"""
|
||||
set state from previous call to getstate
|
||||
|
||||
call Parameterized with the rest of the state
|
||||
"""
|
||||
self.preferred_optimizer = state.pop()
|
||||
self.sampling_runs = state.pop()
|
||||
self.optimization_runs = state.pop()
|
||||
self.priors = state.pop()
|
||||
Parameterised.__setstate__(self, state)
|
||||
Parameterized.setstate(self, state)
|
||||
|
||||
def set_prior(self, regexp, what):
|
||||
"""
|
||||
|
|
@ -355,7 +358,7 @@ class Model(Parameterised):
|
|||
return 0.5 * self._get_params().size * np.log(2 * np.pi) + self.log_likelihood() - hld
|
||||
|
||||
def __str__(self):
|
||||
s = Parameterised.__str__(self).split('\n')
|
||||
s = Parameterized.__str__(self).split('\n')
|
||||
# add priors to the string
|
||||
if self.priors is not None:
|
||||
strs = [str(p) if p is not None else '' for p in self.priors]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue