all parameters in memory

This commit is contained in:
Max Zwiessele 2014-03-10 08:21:13 +00:00
parent f7223ea377
commit 546d5dfff3
9 changed files with 135 additions and 103 deletions

View file

@ -65,8 +65,8 @@ class Parameterized(Parameterizable, Pickleable):
# **Never** call parameters_changed() yourself
__metaclass__ = ParametersChangedMeta
#===========================================================================
def __init__(self, name=None, *a, **kw):
super(Parameterized, self).__init__(name=name, parent=None, parent_index=None, *a, **kw)
def __init__(self, name=None, parameters=[], *a, **kw):
super(Parameterized, self).__init__(name=name, *a, **kw)
self._in_init_ = True
self._parameters_ = ArrayList()
self.size = sum(p.size for p in self._parameters_)
@ -76,6 +76,7 @@ class Parameterized(Parameterizable, Pickleable):
self._param_slices_ = []
self._connect_parameters()
del self._in_init_
self.add_parameters(*parameters)
def build_pydot(self, G=None):
import pydot # @UnresolvedImport
@ -205,25 +206,29 @@ class Parameterized(Parameterizable, Pickleable):
return found_params
def __getitem__(self, name, paramlist=None):
if paramlist is None:
paramlist = self.grep_param_names(name)
if len(paramlist) < 1: raise AttributeError, name
if len(paramlist) == 1:
if isinstance(paramlist[-1], Parameterized):
paramlist = paramlist[-1].flattened_parameters
if len(paramlist) != 1:
return ParamConcatenation(paramlist)
return paramlist[-1]
return ParamConcatenation(paramlist)
if isinstance(name, (int, slice, tuple, np.ndarray)):
return self._param_array_[name]
else:
if paramlist is None:
paramlist = self.grep_param_names(name)
if len(paramlist) < 1: raise AttributeError, name
if len(paramlist) == 1:
if isinstance(paramlist[-1], Parameterized):
paramlist = paramlist[-1].flattened_parameters
if len(paramlist) != 1:
return ParamConcatenation(paramlist)
return paramlist[-1]
return ParamConcatenation(paramlist)
def __setitem__(self, name, value, paramlist=None):
if isinstance(name, (slice, tuple, np.ndarray)):
self._param_array_[name] = value
self.notify_observers()
else:
try: param = self.__getitem__(name, paramlist)
except AttributeError as a: raise a
param[:] = value
def __setattr__(self, name, val):
# override the default behaviour, if setting a param, so broadcasting can by used
if hasattr(self, '_parameters_'):