mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
gradient checker more robust against name changes
This commit is contained in:
parent
bb05c6a02f
commit
76840b6e6e
1 changed files with 7 additions and 4 deletions
|
|
@ -76,9 +76,9 @@ class GradientChecker(Model):
|
||||||
self.shapes = [get_shape(x0)]
|
self.shapes = [get_shape(x0)]
|
||||||
for name, xi in zip(self.names, at_least_one_element(x0)):
|
for name, xi in zip(self.names, at_least_one_element(x0)):
|
||||||
self.__setattr__(name, xi)
|
self.__setattr__(name, xi)
|
||||||
self._param_names = []
|
# self._param_names = []
|
||||||
for name, shape in zip(self.names, self.shapes):
|
# for name, shape in zip(self.names, self.shapes):
|
||||||
self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
# self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self.f = f
|
self.f = f
|
||||||
|
|
@ -108,4 +108,7 @@ class GradientChecker(Model):
|
||||||
current_index += current_size
|
current_index += current_size
|
||||||
|
|
||||||
def _get_param_names(self):
|
def _get_param_names(self):
|
||||||
return self._param_names
|
_param_names = []
|
||||||
|
for name, shape in zip(self.names, self.shapes):
|
||||||
|
_param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
||||||
|
return _param_names
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue