Dont call parameters_changed ever yourself anymore and parameters are now inplace once in memory

This commit is contained in:
Max Zwiessele 2014-03-04 17:32:46 +00:00
parent 56d749ded8
commit 0df263956f
21 changed files with 601 additions and 284 deletions

View file

@ -54,7 +54,7 @@ class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
obj._tied_to_me_ = SetDict()
obj._tied_to_ = []
obj._original_ = True
obj._gradient_ = None
obj._gradient_array_ = numpy.zeros(obj.shape, dtype=numpy.float64)
return obj
def __init__(self, name, input_array, default_constraint=None, *a, **kw):
@ -77,7 +77,7 @@ class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
# see InfoArray.__array_finalize__ for comments
if obj is None: return
super(Param, self).__array_finalize__(obj)
self._direct_parent_ = getattr(obj, '_direct_parent_', None)
self._parent_ = getattr(obj, '_parent_', None)
self._parent_index_ = getattr(obj, '_parent_index_', None)
self._default_constraint_ = getattr(obj, '_default_constraint_', None)
self._current_slice_ = getattr(obj, '_current_slice_', None)
@ -89,16 +89,18 @@ class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
self._updated_ = getattr(obj, '_updated_', None)
self._original_ = getattr(obj, '_original_', None)
self._name = getattr(obj, 'name', None)
self._gradient_ = getattr(obj, '_gradient_', None)
self._gradient_array_ = getattr(obj, '_gradient_array_', None)
self.constraints = getattr(obj, 'constraints', None)
self.priors = getattr(obj, 'priors', None)
@property
def _param_array_(self):
return self
@property
def gradient(self):
if self._gradient_ is None:
self._gradient_ = numpy.zeros(self._realshape_)
return self._gradient_[self._current_slice_]
return self._gradient_array_[self._current_slice_]
@gradient.setter
def gradient(self, val):
self.gradient[:] = val
@ -110,7 +112,7 @@ class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
func, args, state = super(Param, self).__reduce__()
return func, args, (state,
(self.name,
self._direct_parent_,
self._parent_,
self._parent_index_,
self._default_constraint_,
self._current_slice_,
@ -135,7 +137,7 @@ class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
self._current_slice_ = state.pop()
self._default_constraint_ = state.pop()
self._parent_index_ = state.pop()
self._direct_parent_ = state.pop()
self._parent_ = state.pop()
self.name = state.pop()
def copy(self, *args):
@ -148,20 +150,20 @@ class Param(OptimizationHandlable, ObservableArray, Gradcheckable):
#===========================================================================
# get/set parameters
#===========================================================================
def _set_params(self, param, trigger_parent=True):
self.flat = param
if trigger_parent: min_priority = None
else: min_priority = -numpy.inf
self._notify_observers(None, min_priority)
def _get_params(self):
return self.flat
def _collect_gradient(self, target):
target += self.gradient.flat
def _set_gradient(self, g):
self.gradient = g.reshape(self._realshape_)
# def _set_params(self, param, trigger_parent=True):
# self.flat = param
# if trigger_parent: min_priority = None
# else: min_priority = -numpy.inf
# self.notify_observers(None, min_priority)
#
# def _get_params(self):
# return self.flat
#
# def _collect_gradient(self, target):
# target += self.gradient.flat
#
# def _set_gradient(self, g):
# self.gradient = g.reshape(self._realshape_)
#===========================================================================
# Array operations -> done
@ -362,7 +364,7 @@ class ParamConcatenation(object):
parents = dict()
for p in self.params:
if p.has_parent():
parent = p._direct_parent_
parent = p._parent_
level = 0
while parent is not None:
if parent in parents:
@ -370,7 +372,7 @@ class ParamConcatenation(object):
else:
parents[parent] = level
level += 1
parent = parent._direct_parent_
parent = parent._parent_
import operator
self.parents = map(lambda x: x[0], sorted(parents.iteritems(), key=operator.itemgetter(1)))
#===========================================================================
@ -397,7 +399,7 @@ class ParamConcatenation(object):
#===========================================================================
def update_all_params(self):
for par in self.parents:
par._notify_observers(-numpy.inf)
par.notify_observers(-numpy.inf)
def constrain(self, constraint, warning=True):
[param.constrain(constraint, trigger_parent=False) for param in self.params]