gradient checker more robust against name changes

This commit is contained in:
Max Zwiessele 2013-07-29 15:26:49 +01:00
parent bb05c6a02f
commit 76840b6e6e

View file

@ -76,9 +76,9 @@ class GradientChecker(Model):
self.shapes = [get_shape(x0)] self.shapes = [get_shape(x0)]
for name, xi in zip(self.names, at_least_one_element(x0)): for name, xi in zip(self.names, at_least_one_element(x0)):
self.__setattr__(name, xi) self.__setattr__(name, xi)
self._param_names = [] # self._param_names = []
for name, shape in zip(self.names, self.shapes): # for name, shape in zip(self.names, self.shapes):
self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape)))))) # self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
self.f = f self.f = f
@ -108,4 +108,7 @@ class GradientChecker(Model):
current_index += current_size current_index += current_size
def _get_param_names(self): def _get_param_names(self):
return self._param_names _param_names = []
for name, shape in zip(self.names, self.shapes):
_param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
return _param_names