object without args

This commit is contained in:
Max Zwiessele 2014-03-13 11:51:54 +00:00
parent e9c96632ba
commit b95cc90ffb
2 changed files with 70 additions and 70 deletions

View file

@ -94,15 +94,15 @@ class Param(OptimizationHandlable, ObservableArray):
@property
def _param_array_(self):
return self
@property
def gradient(self):
return self._gradient_array_[self._current_slice_]
@gradient.setter
def gradient(self, val):
self.gradient[:] = val
#===========================================================================
# Pickling operations
#===========================================================================
@ -135,7 +135,7 @@ class Param(OptimizationHandlable, ObservableArray):
self._parent_index_ = state.pop()
self._parent_ = state.pop()
self.name = state.pop()
def copy(self, *args):
constr = self.constraints.copy()
priors = self.priors.copy()
@ -151,13 +151,13 @@ class Param(OptimizationHandlable, ObservableArray):
# if trigger_parent: min_priority = None
# else: min_priority = -numpy.inf
# self.notify_observers(None, min_priority)
#
#
# def _get_params(self):
# return self.flat
#
#
# def _collect_gradient(self, target):
# target += self.gradient.flat
#
#
# def _set_gradient(self, g):
# self.gradient = g.reshape(self._realshape_)
@ -173,10 +173,10 @@ class Param(OptimizationHandlable, ObservableArray):
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
except AttributeError: pass # returning 0d array or float, double etc
return new_arr
def __setitem__(self, s, val):
super(Param, self).__setitem__(s, val)
#===========================================================================
# Index Operations:
#===========================================================================
@ -195,7 +195,7 @@ class Param(OptimizationHandlable, ObservableArray):
a = self._realshape_[i] + a
internal_offset += a * extended_realshape[i]
return internal_offset
def _raveled_index(self, slice_index=None):
# return an index array on the raveled array, which is formed by the current_slice
# of this object
@ -203,7 +203,7 @@ class Param(OptimizationHandlable, ObservableArray):
ind = self._indices(slice_index)
if ind.ndim < 2: ind = ind[:, None]
return numpy.asarray(numpy.apply_along_axis(lambda x: numpy.sum(extended_realshape * x), 1, ind), dtype=int)
def _expand_index(self, slice_index=None):
# this calculates the full indexing arrays from the slicing objects given by get_item for _real..._ attributes
# it basically translates slices to their respective index arrays and turns negative indices around
@ -245,7 +245,7 @@ class Param(OptimizationHandlable, ObservableArray):
#===========================================================================
@property
def _description_str(self):
if self.size <= 1:
if self.size <= 1:
return [str(self.view(numpy.ndarray)[0])]
else: return [str(self.shape)]
def parameter_names(self, add_self=False, adjust_for_printing=False):
@ -356,7 +356,7 @@ class ParamConcatenation(object):
self._param_sizes = [p.size for p in self.params]
startstops = numpy.cumsum([0] + self._param_sizes)
self._param_slices_ = [slice(start, stop) for start,stop in zip(startstops, startstops[1:])]
parents = dict()
for p in self.params:
if p.has_parent():
@ -396,7 +396,7 @@ class ParamConcatenation(object):
def update_all_params(self):
for par in self.parents:
par.notify_observers(-numpy.inf)
def constrain(self, constraint, warning=True):
[param.constrain(constraint, trigger_parent=False) for param in self.params]
self.update_all_params()