[updates] updated update structure immensely

This commit is contained in:
Max Zwiessele 2014-09-22 09:26:01 +01:00
parent 4543fc3480
commit ef05f49b8b
5 changed files with 99 additions and 53 deletions

View file

@ -49,7 +49,7 @@ class Param(Parameterizable, ObsAr):
obj._realshape_ = obj.shape
obj._realsize_ = obj.size
obj._realndim_ = obj.ndim
obj._original_ = True
obj._original_ = obj
return obj
def __init__(self, name, input_array, default_constraint=None, *a, **kw):
@ -124,10 +124,10 @@ class Param(Parameterizable, ObsAr):
#if not reduce(lambda a, b: a or numpy.any(b is Ellipsis), s, False) and len(s) <= self.ndim:
# s += (Ellipsis,)
new_arr = super(Param, self).__getitem__(s, *args, **kwargs)
try:
try:
new_arr._current_slice_ = s
new_arr._gradient_array_ = self.gradient[s]
new_arr._original_ = self.base is new_arr.base
new_arr._original_ = self._original_
except AttributeError: pass # returning 0d array or float, double etc
return new_arr
@ -157,29 +157,29 @@ class Param(Parameterizable, ObsAr):
return self.constraints[__fixed__].size == self.size
def _get_original(self, param):
return self
return self._original_
#===========================================================================
# Pickling and copying
#===========================================================================
def copy(self):
return Parameterizable.copy(self, which=self)
def __deepcopy__(self, memo):
s = self.__new__(self.__class__, name=self.name, input_array=self.view(numpy.ndarray).copy())
memo[id(self)] = s
memo[id(self)] = s
import copy
Pickleable.__setstate__(s, copy.deepcopy(self.__getstate__(), memo))
return s
def _setup_observers(self):
"""
Setup the default observers
1: pass through to parent, if present
"""
if self.has_parent():
self.add_observer(self._parent_, self._parent_._pass_through_notify_observers, -np.inf)
#===========================================================================
# Printing -> done
#===========================================================================

View file

@ -18,7 +18,7 @@ import numpy as np
import re
import logging
__updated__ = '2014-05-21'
__updated__ = '2014-09-22'
class HierarchyError(Exception):
"""
@ -63,7 +63,7 @@ class Observable(object):
@updates.setter
def updates(self, ups):
raise DeprecationWarning("updates is now a function, see update(True|False|None)")
def update_model(self, updates=None):
"""
Get or set, whether automatic updates are performed. When updates are
@ -87,21 +87,23 @@ class Observable(object):
else:
self._updates = updates
self.trigger_update()
def toggle_update(self):
self.update_model(not self.update())
def trigger_update(self):
def trigger_update(self, trigger_parent=True):
"""
Update the model from the current state.
Make sure that updates are on, otherwise this
method will do nothing
:param bool trigger_parent: Whether to trigger the parent, after self has updated
"""
if not self.update_model():
#print "Warning: updates are off, updating the model will do nothing"
return
self._trigger_params_changed()
self._trigger_params_changed(trigger_parent)
def add_observer(self, observer, callble, priority=0):
"""
Add an observer `observer` with the callback `callble`
@ -539,18 +541,18 @@ class Indexable(Nameable, Observable):
[np.put(ret, ind, p.lnpdf_grad(x[ind])) for p, ind in self.priors.iteritems()]
return ret
return 0.
#===========================================================================
# Tie parameters together
#===========================================================================
def _has_ties(self):
if self._highest_parent_.tie.tied_param is None:
return False
if self.has_parent():
return self._highest_parent_.tie.label_buf[self._highest_parent_._raveled_index_for(self)].sum()>0
return True
def tie_together(self):
self._highest_parent_.tie.add_tied_parameter(self)
self._highest_parent_._set_fixed(self,self._raveled_index())
@ -740,7 +742,7 @@ class OptimizationHandlable(Indexable):
self.param_array.flat[f] = p
[np.put(self.param_array, ind[f[ind]], c.f(self.param_array.flat[ind[f[ind]]]))
for c, ind in self.constraints.iteritems() if c != __fixed__]
self._highest_parent_.tie.propagate_val()
#self._highest_parent_.tie.propagate_val()
self._optimizer_copy_transformed = False
self._trigger_params_changed()
@ -826,6 +828,7 @@ class OptimizationHandlable(Indexable):
"""
# first take care of all parameters (from N(0,1))
x = rand_gen(size=self._size_transformed(), *args, **kwargs)
updates = self.update_model()
self.update_model(False) # Switch off the updates
self.optimizer_array = x # makes sure all of the tied parameters get the same init (since there's only one prior object...)
# now draw from prior where possible
@ -833,8 +836,8 @@ class OptimizationHandlable(Indexable):
[np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.iteritems() if not p is None]
unfixlist = np.ones((self.size,),dtype=np.bool)
unfixlist[self.constraints[__fixed__]] = False
self.param_array[unfixlist] = x[unfixlist]
self.update_model(True)
self.param_array.flat[unfixlist] = x.view(np.ndarray).ravel()[unfixlist]
self.update_model(updates)
#===========================================================================
# For shared memory arrays. This does nothing in Param, but sets the memory
@ -928,7 +931,7 @@ class Parameterizable(OptimizationHandlable):
"""
if self.__dict__.get('_param_array_', None) is None:
self._param_array_ = np.empty(self.size, dtype=np.float64)
if self.constraints[__fixed__].size !=0:
fixes = np.ones(self.size).astype(bool)
fixes[self.constraints[__fixed__]] = FIXED