mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
Modified gradient_checker to allow for variable 'f'
This commit is contained in:
parent
1dd83291fe
commit
64e65b846d
1 changed files with 15 additions and 15 deletions
|
|
@ -81,8 +81,8 @@ class GradientChecker(Model):
|
||||||
# self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
# self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self.f = f
|
self._f = f
|
||||||
self.df = df
|
self._df = df
|
||||||
|
|
||||||
def _get_x(self):
|
def _get_x(self):
|
||||||
if len(self.names) > 1:
|
if len(self.names) > 1:
|
||||||
|
|
@ -90,10 +90,10 @@ class GradientChecker(Model):
|
||||||
return [self.__getattribute__(self.names[0])] + list(self.args)
|
return [self.__getattribute__(self.names[0])] + list(self.args)
|
||||||
|
|
||||||
def log_likelihood(self):
|
def log_likelihood(self):
|
||||||
return float(numpy.sum(self.f(*self._get_x(), **self.kwargs)))
|
return float(numpy.sum(self._f(*self._get_x(), **self.kwargs)))
|
||||||
|
|
||||||
def _log_likelihood_gradients(self):
|
def _log_likelihood_gradients(self):
|
||||||
return numpy.atleast_1d(self.df(*self._get_x(), **self.kwargs)).flatten()
|
return numpy.atleast_1d(self._df(*self._get_x(), **self.kwargs)).flatten()
|
||||||
|
|
||||||
|
|
||||||
def _get_params(self):
|
def _get_params(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue