new gradient handling way nicer

This commit is contained in:
Max Zwiessele 2014-01-24 15:07:28 +00:00
parent f0ac290eb3
commit e128059377
6 changed files with 15 additions and 26 deletions

View file

@ -34,9 +34,9 @@ class Model(Parameterized):
g = np.zeros(self.size)
try:
#[g.__setitem__(s, self.gradient_mapping[p]().flat) for p, s in itertools.izip(self._parameters_, self._param_slices_) if not p.is_fixed]
[g.__setitem__(s, p.gradient.flat) for p, s in itertools.izip(self._parameters_, self._param_slices_) if not p.is_fixed]
except KeyError:
raise KeyError, 'Gradient for {} not defined, please specify gradients for parameters to optimize'.format(p.name)
[p._collect_gradient(g[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_) if not p.is_fixed]
except ValueError:
raise ValueError, 'Gradient for {} not defined, please specify gradients for parameters to optimize'.format(p.name)
return g
raise NotImplementedError, "this needs to be implemented to use the model class"