mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-13 14:03:20 +02:00
[updates] made updates a function
This commit is contained in:
parent
32469e0461
commit
9a74a933f3
3 changed files with 45 additions and 19 deletions
|
|
@ -249,7 +249,7 @@ class Param(Parameterizable, ObsAr):
|
||||||
try:
|
try:
|
||||||
indices = np.indices(self._realshape_, dtype=int)
|
indices = np.indices(self._realshape_, dtype=int)
|
||||||
indices = indices[(slice(None),)+slice_index]
|
indices = indices[(slice(None),)+slice_index]
|
||||||
indices = np.rollaxis(indices, 0, indices.ndim).reshape(-1,2)
|
indices = np.rollaxis(indices, 0, indices.ndim).reshape(-1,self._realndim_)
|
||||||
#print indices_
|
#print indices_
|
||||||
#if not np.all(indices==indices__):
|
#if not np.all(indices==indices__):
|
||||||
# import ipdb; ipdb.set_trace()
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
|
||||||
|
|
@ -50,30 +50,50 @@ class Observable(object):
|
||||||
as an observer. Every time the observable changes, it sends a notification with
|
as an observer. Every time the observable changes, it sends a notification with
|
||||||
self as only argument to all its observers.
|
self as only argument to all its observers.
|
||||||
"""
|
"""
|
||||||
_updates = True
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Observable, self).__init__()
|
super(Observable, self).__init__()
|
||||||
from lists_and_dicts import ObserverList
|
from lists_and_dicts import ObserverList
|
||||||
self.observers = ObserverList()
|
self.observers = ObserverList()
|
||||||
|
self._updates = True
|
||||||
|
|
||||||
@property
|
def updates(self, updates=None):
|
||||||
def updates(self):
|
"""
|
||||||
|
Get or set, whether automatic updates are performed. When updates are
|
||||||
|
off, the model might be in a non-working state. To make the model work
|
||||||
|
turn updates on again.
|
||||||
|
|
||||||
|
:param bool|None updates:
|
||||||
|
|
||||||
|
bool: whether to do updates
|
||||||
|
None: get the current update state
|
||||||
|
"""
|
||||||
|
if updates is None:
|
||||||
|
p = getattr(self, '_highest_parent_', None)
|
||||||
|
if p is not None:
|
||||||
|
self._updates = p._updates
|
||||||
|
return self._updates
|
||||||
|
assert isinstance(updates, bool), "updates are either on (True) or off (False)"
|
||||||
p = getattr(self, '_highest_parent_', None)
|
p = getattr(self, '_highest_parent_', None)
|
||||||
if p is not None:
|
if p is not None:
|
||||||
self._updates = p._updates
|
p._updates = updates
|
||||||
return self._updates
|
|
||||||
|
|
||||||
@updates.setter
|
|
||||||
def updates(self, ups):
|
|
||||||
assert isinstance(ups, bool), "updates are either on (True) or off (False)"
|
|
||||||
p = getattr(self, '_highest_parent_', None)
|
|
||||||
if p is not None:
|
|
||||||
p._updates = ups
|
|
||||||
else:
|
else:
|
||||||
self._updates = ups
|
self._updates = updates
|
||||||
if ups:
|
self.update_model()
|
||||||
self._trigger_params_changed()
|
|
||||||
|
def toggle_updates(self):
|
||||||
|
self.updates(not self.updates())
|
||||||
|
|
||||||
|
def update_model(self):
|
||||||
|
"""
|
||||||
|
Update the model from the current state.
|
||||||
|
Make sure that updates are on, otherwise this
|
||||||
|
method will do nothing
|
||||||
|
"""
|
||||||
|
if not self.updates():
|
||||||
|
#print "Warning: updates are off, updating the model will do nothing"
|
||||||
|
return
|
||||||
|
self._trigger_params_changed()
|
||||||
|
|
||||||
def add_observer(self, observer, callble, priority=0):
|
def add_observer(self, observer, callble, priority=0):
|
||||||
"""
|
"""
|
||||||
Add an observer `observer` with the callback `callble`
|
Add an observer `observer` with the callback `callble`
|
||||||
|
|
@ -110,7 +130,7 @@ class Observable(object):
|
||||||
:param min_priority: only notify observers with priority > min_priority
|
:param min_priority: only notify observers with priority > min_priority
|
||||||
if min_priority is None, notify all observers in order
|
if min_priority is None, notify all observers in order
|
||||||
"""
|
"""
|
||||||
if not self.updates:
|
if not self.updates():
|
||||||
return
|
return
|
||||||
if which is None:
|
if which is None:
|
||||||
which = self
|
which = self
|
||||||
|
|
@ -798,7 +818,7 @@ class OptimizationHandlable(Indexable):
|
||||||
"""
|
"""
|
||||||
# first take care of all parameters (from N(0,1))
|
# first take care of all parameters (from N(0,1))
|
||||||
x = rand_gen(size=self._size_transformed(), *args, **kwargs)
|
x = rand_gen(size=self._size_transformed(), *args, **kwargs)
|
||||||
self.updates = False # Switch off the updates
|
self.updates(False) # Switch off the updates
|
||||||
self.optimizer_array = x # makes sure all of the tied parameters get the same init (since there's only one prior object...)
|
self.optimizer_array = x # makes sure all of the tied parameters get the same init (since there's only one prior object...)
|
||||||
# now draw from prior where possible
|
# now draw from prior where possible
|
||||||
x = self.param_array.copy()
|
x = self.param_array.copy()
|
||||||
|
|
@ -806,7 +826,7 @@ class OptimizationHandlable(Indexable):
|
||||||
unfixlist = np.ones((self.size,),dtype=np.bool)
|
unfixlist = np.ones((self.size,),dtype=np.bool)
|
||||||
unfixlist[self.constraints[__fixed__]] = False
|
unfixlist[self.constraints[__fixed__]] = False
|
||||||
self.param_array[unfixlist] = x[unfixlist]
|
self.param_array[unfixlist] = x[unfixlist]
|
||||||
self.updates = True
|
self.updates(True)
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# For shared memory arrays. This does nothing in Param, but sets the memory
|
# For shared memory arrays. This does nothing in Param, but sets the memory
|
||||||
|
|
|
||||||
|
|
@ -152,6 +152,12 @@ class ParameterizedTest(unittest.TestCase):
|
||||||
self.test1.kern.randomize()
|
self.test1.kern.randomize()
|
||||||
self.assertEqual(val, self.rbf.variance)
|
self.assertEqual(val, self.rbf.variance)
|
||||||
|
|
||||||
|
def test_updates(self):
|
||||||
|
self.test1.updates = False
|
||||||
|
val = float(self.rbf.variance)
|
||||||
|
self.test1.kern.randomize()
|
||||||
|
self.assertEqual(val, self.rbf.variance)
|
||||||
|
|
||||||
def test_fixing_optimize(self):
|
def test_fixing_optimize(self):
|
||||||
self.testmodel.kern.lengthscale.fix()
|
self.testmodel.kern.lengthscale.fix()
|
||||||
val = float(self.testmodel.kern.lengthscale)
|
val = float(self.testmodel.kern.lengthscale)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue