mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
pickling working for array-likes, but observers not yet connected back
This commit is contained in:
parent
19dc7cecf4
commit
2ce3a93b3f
5 changed files with 35 additions and 21 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
__updated__ = '2013-12-16'
|
__updated__ = '2014-03-17'
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parameter_core import Observable
|
from parameter_core import Observable
|
||||||
|
|
@ -18,7 +18,7 @@ class ObservableArray(np.ndarray, Observable):
|
||||||
if not isinstance(input_array, ObservableArray):
|
if not isinstance(input_array, ObservableArray):
|
||||||
obj = np.atleast_1d(np.require(input_array, dtype=np.float64, requirements=['W', 'C'])).view(cls)
|
obj = np.atleast_1d(np.require(input_array, dtype=np.float64, requirements=['W', 'C'])).view(cls)
|
||||||
else: obj = input_array
|
else: obj = input_array
|
||||||
cls.__name__ = "ObservableArray\n "
|
cls.__name__ = "ObsAr" # because of fixed printing of `array` in np printing
|
||||||
super(ObservableArray, obj).__init__(*a, **kw)
|
super(ObservableArray, obj).__init__(*a, **kw)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
@ -30,6 +30,14 @@ class ObservableArray(np.ndarray, Observable):
|
||||||
def __array_wrap__(self, out_arr, context=None):
|
def __array_wrap__(self, out_arr, context=None):
|
||||||
return out_arr.view(np.ndarray)
|
return out_arr.view(np.ndarray)
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
func, args, state = np.ndarray.__reduce__(self)
|
||||||
|
return func, args, (state, Observable._getstate(self))
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
np.ndarray.__setstate__(self, state[0])
|
||||||
|
Observable._setstate(self, state[1])
|
||||||
|
|
||||||
def _s_not_empty(self, s):
|
def _s_not_empty(self, s):
|
||||||
# this checks whether there is something picked by this slice.
|
# this checks whether there is something picked by this slice.
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -269,6 +269,8 @@ class Param(OptimizationHandlable, ObservableArray):
|
||||||
@property
|
@property
|
||||||
def _ties_str(self):
|
def _ties_str(self):
|
||||||
return ['']
|
return ['']
|
||||||
|
def _ties_for(self, ravi):
|
||||||
|
return [['N/A']]*ravi.size
|
||||||
def __repr__(self, *args, **kwargs):
|
def __repr__(self, *args, **kwargs):
|
||||||
name = "\033[1m{x:s}\033[0;0m:\n".format(
|
name = "\033[1m{x:s}\033[0;0m:\n".format(
|
||||||
x=self.hierarchy_name())
|
x=self.hierarchy_name())
|
||||||
|
|
@ -312,7 +314,7 @@ class Param(OptimizationHandlable, ObservableArray):
|
||||||
ravi = self._raveled_index(filter_)
|
ravi = self._raveled_index(filter_)
|
||||||
if constr_matrix is None: constr_matrix = self.constraints.properties_for(ravi)
|
if constr_matrix is None: constr_matrix = self.constraints.properties_for(ravi)
|
||||||
if prirs is None: prirs = self.priors.properties_for(ravi)
|
if prirs is None: prirs = self.priors.properties_for(ravi)
|
||||||
if ties is None: ties = [['N/A']]*self.size
|
if ties is None: ties = self._ties_for(ravi)
|
||||||
ties = [' '.join(map(lambda x: x, t)) for t in ties]
|
ties = [' '.join(map(lambda x: x, t)) for t in ties]
|
||||||
if lc is None: lc = self._max_len_names(constr_matrix, __constraints_name__)
|
if lc is None: lc = self._max_len_names(constr_matrix, __constraints_name__)
|
||||||
if lx is None: lx = self._max_len_values()
|
if lx is None: lx = self._max_len_values()
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ Observable Pattern for patameterization
|
||||||
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
__updated__ = '2014-03-14'
|
__updated__ = '2014-03-17'
|
||||||
|
|
||||||
class HierarchyError(Exception):
|
class HierarchyError(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
@ -56,7 +56,7 @@ class InterfacePickleFunctions(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError, "To be able to use pickling you need to implement this method"
|
raise NotImplementedError, "To be able to use pickling you need to implement this method"
|
||||||
|
|
||||||
class Pickleable(object):
|
class Pickleable(InterfacePickleFunctions):
|
||||||
"""
|
"""
|
||||||
Make an object pickleable (See python doc 'pickling').
|
Make an object pickleable (See python doc 'pickling').
|
||||||
|
|
||||||
|
|
@ -95,7 +95,7 @@ class Pickleable(object):
|
||||||
def _has_get_set_state(self):
|
def _has_get_set_state(self):
|
||||||
return '_getstate' in vars(self.__class__) and '_setstate' in vars(self.__class__)
|
return '_getstate' in vars(self.__class__) and '_setstate' in vars(self.__class__)
|
||||||
|
|
||||||
class Observable(InterfacePickleFunctions):
|
class Observable(Pickleable):
|
||||||
"""
|
"""
|
||||||
Observable pattern for parameterization.
|
Observable pattern for parameterization.
|
||||||
|
|
||||||
|
|
@ -155,6 +155,7 @@ class Observable(InterfacePickleFunctions):
|
||||||
|
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
return [self._observer_callables_]
|
return [self._observer_callables_]
|
||||||
|
|
||||||
def _setstate(self, state):
|
def _setstate(self, state):
|
||||||
self._observer_callables_ = state.pop()
|
self._observer_callables_ = state.pop()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from GPy.core.parameterization.parameterized import Parameterized
|
||||||
from GPy.core.parameterization.param import Param
|
from GPy.core.parameterization.param import Param
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
# One trigger in init
|
# One trigger in init
|
||||||
_trigger_start = -1
|
_trigger_start = -1
|
||||||
|
|
||||||
class ParamTestParent(Parameterized):
|
class ParamTestParent(Parameterized):
|
||||||
|
|
@ -21,11 +21,9 @@ class ParameterizedTest(Parameterized):
|
||||||
params_changed_count = _trigger_start
|
params_changed_count = _trigger_start
|
||||||
def parameters_changed(self):
|
def parameters_changed(self):
|
||||||
self.params_changed_count += 1
|
self.params_changed_count += 1
|
||||||
def _set_params(self, params, trigger_parent=True):
|
|
||||||
Parameterized._set_params(self, params, trigger_parent=trigger_parent)
|
|
||||||
|
|
||||||
class Test(unittest.TestCase):
|
class Test(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.parent = ParamTestParent('test parent')
|
self.parent = ParamTestParent('test parent')
|
||||||
self.par = ParameterizedTest('test model')
|
self.par = ParameterizedTest('test model')
|
||||||
|
|
@ -41,12 +39,12 @@ class Test(unittest.TestCase):
|
||||||
|
|
||||||
self.parent.add_parameter(self.par)
|
self.parent.add_parameter(self.par)
|
||||||
self.parent.add_parameter(self.par2)
|
self.parent.add_parameter(self.par2)
|
||||||
|
|
||||||
self._observer_triggered = None
|
self._observer_triggered = None
|
||||||
self._trigger_count = 0
|
self._trigger_count = 0
|
||||||
self._first = None
|
self._first = None
|
||||||
self._second = None
|
self._second = None
|
||||||
|
|
||||||
def _trigger(self, which):
|
def _trigger(self, which):
|
||||||
self._observer_triggered = float(which)
|
self._observer_triggered = float(which)
|
||||||
self._trigger_count += 1
|
self._trigger_count += 1
|
||||||
|
|
@ -54,18 +52,18 @@ class Test(unittest.TestCase):
|
||||||
self._second = self._trigger
|
self._second = self._trigger
|
||||||
else:
|
else:
|
||||||
self._first = self._trigger
|
self._first = self._trigger
|
||||||
|
|
||||||
def _trigger_priority(self, which):
|
def _trigger_priority(self, which):
|
||||||
if self._first is not None:
|
if self._first is not None:
|
||||||
self._second = self._trigger_priority
|
self._second = self._trigger_priority
|
||||||
else:
|
else:
|
||||||
self._first = self._trigger_priority
|
self._first = self._trigger_priority
|
||||||
|
|
||||||
def test_observable(self):
|
def test_observable(self):
|
||||||
self.par.add_observer(self, self._trigger, -1)
|
self.par.add_observer(self, self._trigger, -1)
|
||||||
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
|
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
|
||||||
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
self.p[0,1] = 3 # trigger observers
|
self.p[0,1] = 3 # trigger observers
|
||||||
self.assertEqual(self._observer_triggered, 3, 'observer should have triggered')
|
self.assertEqual(self._observer_triggered, 3, 'observer should have triggered')
|
||||||
self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
|
self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
|
||||||
|
|
@ -78,14 +76,14 @@ class Test(unittest.TestCase):
|
||||||
self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
|
self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
|
||||||
self.assertEqual(self.par.params_changed_count, 2, 'params changed second')
|
self.assertEqual(self.par.params_changed_count, 2, 'params changed second')
|
||||||
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
self.par.add_observer(self, self._trigger, -1)
|
self.par.add_observer(self, self._trigger, -1)
|
||||||
self.p[2,1] = 4
|
self.p[2,1] = 4
|
||||||
self.assertEqual(self._observer_triggered, 4, 'observer should have triggered')
|
self.assertEqual(self._observer_triggered, 4, 'observer should have triggered')
|
||||||
self.assertEqual(self._trigger_count, 2, 'observer should have triggered once')
|
self.assertEqual(self._trigger_count, 2, 'observer should have triggered once')
|
||||||
self.assertEqual(self.par.params_changed_count, 3, 'params changed second')
|
self.assertEqual(self.par.params_changed_count, 3, 'params changed second')
|
||||||
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
self.par.remove_observer(self, self._trigger)
|
self.par.remove_observer(self, self._trigger)
|
||||||
self.p[0,1] = 3
|
self.p[0,1] = 3
|
||||||
self.assertEqual(self._observer_triggered, 4, 'observer should not have triggered')
|
self.assertEqual(self._observer_triggered, 4, 'observer should not have triggered')
|
||||||
|
|
@ -99,7 +97,7 @@ class Test(unittest.TestCase):
|
||||||
self.par._trigger_params_changed()
|
self.par._trigger_params_changed()
|
||||||
self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
|
self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
|
||||||
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
|
self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)
|
||||||
|
|
||||||
self.par._param_array_[:] = 2
|
self.par._param_array_[:] = 2
|
||||||
self.par._trigger_params_changed()
|
self.par._trigger_params_changed()
|
||||||
self.assertEqual(self.par.params_changed_count, 2, 'now params changed')
|
self.assertEqual(self.par.params_changed_count, 2, 'now params changed')
|
||||||
|
|
@ -125,13 +123,13 @@ class Test(unittest.TestCase):
|
||||||
|
|
||||||
self.par.remove_observer(self)
|
self.par.remove_observer(self)
|
||||||
self._first = self._second = None
|
self._first = self._second = None
|
||||||
|
|
||||||
self.par.add_observer(self, self._trigger, 1)
|
self.par.add_observer(self, self._trigger, 1)
|
||||||
self.par.add_observer(self, self._trigger_priority, 0)
|
self.par.add_observer(self, self._trigger_priority, 0)
|
||||||
self.par.notify_observers(0)
|
self.par.notify_observers(0)
|
||||||
self.assertEqual(self._first, self._trigger, 'priority should be second')
|
self.assertEqual(self._first, self._trigger, 'priority should be second')
|
||||||
self.assertEqual(self._second, self._trigger_priority, 'priority should be second')
|
self.assertEqual(self._second, self._trigger_priority, 'priority should be second')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#import sys;sys.argv = ['', 'Test.testName']
|
#import sys;sys.argv = ['', 'Test.testName']
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ class ParameterizedTest(unittest.TestCase):
|
||||||
self.assertEqual(self.param.constraints._offset, 3)
|
self.assertEqual(self.param.constraints._offset, 3)
|
||||||
|
|
||||||
def test_fixing_randomize(self):
|
def test_fixing_randomize(self):
|
||||||
self.white.fix(warning=False)
|
self.white.fix(warning=True)
|
||||||
val = float(self.test1.white.variance)
|
val = float(self.test1.white.variance)
|
||||||
self.test1.randomize()
|
self.test1.randomize()
|
||||||
self.assertEqual(val, self.white.variance)
|
self.assertEqual(val, self.white.variance)
|
||||||
|
|
@ -119,6 +119,11 @@ class ParameterizedTest(unittest.TestCase):
|
||||||
self.testmodel.randomize()
|
self.testmodel.randomize()
|
||||||
self.assertEqual(val, self.testmodel.kern.lengthscale)
|
self.assertEqual(val, self.testmodel.kern.lengthscale)
|
||||||
|
|
||||||
|
def test_printing(self):
|
||||||
|
print self.test1
|
||||||
|
print self.param
|
||||||
|
print self.test1['']
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#import sys;sys.argv = ['', 'Test.test_add_parameter']
|
#import sys;sys.argv = ['', 'Test.test_add_parameter']
|
||||||
unittest.main()
|
unittest.main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue