mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 19:42:39 +02:00
108 lines
4.7 KiB
Python
108 lines
4.7 KiB
Python
|
|
'''
Created on 27 Feb 2014

@author: maxz
'''
|
||
|
|
import unittest
|
||
|
|
from GPy.core.parameterization.parameterized import Parameterized
|
||
|
|
from GPy.core.parameterization.param import Param
|
||
|
|
import numpy
|
||
|
|
|
||
|
|
|
||
|
|
class ParamTestParent(Parameterized):
    """Parameterized container that counts calls to ``parameters_changed``."""

    # Class-level default of 0; incremented through the instance below.
    parent_changed_count = 0

    def parameters_changed(self):
        """Record one more invocation of the change hook."""
        self.parent_changed_count = self.parent_changed_count + 1
|
||
|
|
|
||
|
|
class ParameterizedTest(Parameterized):
    """Parameterized model that counts calls to ``parameters_changed``."""

    # Class-level default of 0; incremented through the instance below.
    params_changed_count = 0

    def parameters_changed(self):
        """Record one more invocation of the change hook."""
        self.params_changed_count = self.params_changed_count + 1
|
||
|
|
|
||
|
|
class Test(unittest.TestCase):
    """Observer and change-notification tests on a small parameter hierarchy.

    The fixture built in ``setUp`` nests a ``Param`` inside a
    ``ParameterizedTest`` inside a ``ParamTestParent``.  The test case itself
    is registered as the observer, with ``_trigger`` and ``_trigger_priority``
    as callbacks.
    """

    def setUp(self):
        """Build the parent -> model -> parameter hierarchy and reset bookkeeping."""
        self.parent = ParamTestParent('test parent')
        self.par = ParameterizedTest('test model')
        self.p = Param('test parameter', numpy.random.normal(1, 2, (10, 3)))

        self.par.add_parameter(self.p)
        self.parent.add_parameter(self.par)

        # State written by the observer callbacks below.
        self._observer_triggered = None
        self._trigger_count = 0
        self._first = None
        self._second = None

    def _record_call_order(self, callback):
        """Store ``callback`` in ``_first`` if it is still unset, else in ``_second``."""
        if self._first is None:
            self._first = callback
        else:
            self._second = callback

    def _assert_parent_in_sync(self):
        """Assert the parent's change count equals the model's change count."""
        self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')

    def _trigger(self, which):
        """Observer callback: store ``float(which)``, count the call, note call order."""
        self._observer_triggered = float(which)
        self._trigger_count += 1
        self._record_call_order(self._trigger)

    def _trigger_priority(self, which):
        """Observer callback that only notes the order in which it was called."""
        self._record_call_order(self._trigger_priority)

    def test_observable(self):
        """Observers fire while registered, stay silent once removed, and the
        model and parent change counts stay equal throughout."""
        self.par.add_observer(self, self._trigger, -1)
        self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
        self._assert_parent_in_sync()

        self.p[0, 1] = 3  # trigger observers
        self.assertEqual(self._observer_triggered, 3, 'observer should have triggered')
        self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
        self.assertEqual(self.par.params_changed_count, 1, 'params changed once')
        self._assert_parent_in_sync()

        # Removing by observer only: the callback must no longer run.
        self.par.remove_observer(self)
        self.p[2, 1] = 4
        self.assertEqual(self._observer_triggered, 3, 'observer should not have triggered')
        self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
        self.assertEqual(self.par.params_changed_count, 2, 'params changed twice')
        self._assert_parent_in_sync()

        # Re-register: the callback runs again on the next assignment.
        self.par.add_observer(self, self._trigger, -1)
        self.p[2, 1] = 4
        self.assertEqual(self._observer_triggered, 4, 'observer should have triggered')
        self.assertEqual(self._trigger_count, 2, 'observer should have triggered twice')
        self.assertEqual(self.par.params_changed_count, 3, 'params changed three times')
        self._assert_parent_in_sync()

        # Removing by observer and callback: the callback must no longer run.
        self.par.remove_observer(self, self._trigger)
        self.p[0, 1] = 3
        self.assertEqual(self._observer_triggered, 4, 'observer should not have triggered')
        self.assertEqual(self._trigger_count, 2, 'observer should still have triggered only twice')
        self.assertEqual(self.par.params_changed_count, 4, 'params changed four times')
        self._assert_parent_in_sync()

    def test_set_params(self):
        """Setting all parameters at once changes model and parent exactly once."""
        self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
        self.par._set_params(numpy.ones(self.par.size))
        self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
        self._assert_parent_in_sync()

    def test_priority(self):
        """Callbacks are notified in the order implied by add_observer's third argument.

        NOTE(review): the third argument looks like a priority where the
        larger value is notified first; that is what the assertions below
        expect -- confirm against ``add_observer``.
        """
        self.par.add_observer(self, self._trigger, -1)
        self.par.add_observer(self, self._trigger_priority, 0)
        self.par._notify_observers(0)
        self.assertEqual(self._first, self._trigger_priority, 'priority should be first')
        self.assertEqual(self._second, self._trigger, 'priority should be first')

        # Start over with the relative order of the two callbacks reversed.
        self.par.remove_observer(self)
        self._first = self._second = None

        self.par.add_observer(self, self._trigger, 1)
        self.par.add_observer(self, self._trigger_priority, 0)
        self.par._notify_observers(0)
        self.assertEqual(self._first, self._trigger, 'priority should be second')
        self.assertEqual(self._second, self._trigger_priority, 'priority should be second')
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # To run a single test, uncomment and adapt the line below:
    # import sys;sys.argv = ['', 'Test.testName']
    unittest.main()
|