Caching functions now take two arguments: self and which, which is the argument which started the update

This commit is contained in:
Max Zwiessele 2014-03-24 08:52:23 +00:00
parent 0b5f6ea7c6
commit a493dd085e
7 changed files with 114 additions and 44 deletions

View file

@ -368,26 +368,26 @@ class ParamConcatenation(object):
#===========================================================================
def __getitem__(self, s):
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
params = [p._get_params()[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p._get_params()[ind[ps]])]
params = [p._param_array_[ind[ps]] for p,ps in zip(self.params, self._param_slices_) if numpy.any(p._param_array_[ind[ps]])]
if len(params)==1: return params[0]
return ParamConcatenation(params)
def __setitem__(self, s, val, update=True):
if isinstance(val, ParamConcatenation):
val = val._vals()
val = val.values()
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
vals = self._vals(); vals[s] = val; del val
vals = self.values(); vals[s] = val; del val
[numpy.place(p, ind[ps], vals[ps])
for p, ps in zip(self.params, self._param_slices_)]
if update:
self.update_all_params()
def _vals(self):
def values(self):
return numpy.hstack([p._param_array_ for p in self.params])
#===========================================================================
# parameter operations:
#===========================================================================
def update_all_params(self):
for par in self.parents:
par.notify_observers(-numpy.inf)
par.notify_observers()
def constrain(self, constraint, warning=True):
[param.constrain(constraint, trigger_parent=False) for param in self.params]
@ -442,12 +442,12 @@ class ParamConcatenation(object):
return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance)
#checkgrad.__doc__ = Gradcheckable.checkgrad.__doc__
__lt__ = lambda self, val: self._vals() < val
__le__ = lambda self, val: self._vals() <= val
__eq__ = lambda self, val: self._vals() == val
__ne__ = lambda self, val: self._vals() != val
__gt__ = lambda self, val: self._vals() > val
__ge__ = lambda self, val: self._vals() >= val
__lt__ = lambda self, val: self.values() < val
__le__ = lambda self, val: self.values() <= val
__eq__ = lambda self, val: self.values() == val
__ne__ = lambda self, val: self.values() != val
__gt__ = lambda self, val: self.values() > val
__ge__ = lambda self, val: self.values() >= val
def __str__(self, *args, **kwargs):
def f(p):
ind = p._raveled_index()