mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 04:22:38 +02:00
gradient checker more robust against name changes
This commit is contained in:
parent
bb05c6a02f
commit
76840b6e6e
1 changed files with 7 additions and 4 deletions
|
|
@ -76,9 +76,9 @@ class GradientChecker(Model):
|
|||
self.shapes = [get_shape(x0)]
|
||||
for name, xi in zip(self.names, at_least_one_element(x0)):
|
||||
self.__setattr__(name, xi)
|
||||
self._param_names = []
|
||||
for name, shape in zip(self.names, self.shapes):
|
||||
self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
||||
# self._param_names = []
|
||||
# for name, shape in zip(self.names, self.shapes):
|
||||
# self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.f = f
|
||||
|
|
@ -108,4 +108,7 @@ class GradientChecker(Model):
|
|||
current_index += current_size
|
||||
|
||||
def _get_param_names(self):
|
||||
return self._param_names
|
||||
_param_names = []
|
||||
for name, shape in zip(self.names, self.shapes):
|
||||
_param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
||||
return _param_names
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue