mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
gradient checker implemented
This commit is contained in:
parent
33916b4d58
commit
d5cb531d40
2 changed files with 28 additions and 7 deletions
|
|
@ -12,3 +12,4 @@ from sparse_gplvm import SparseGPLVM
|
||||||
from warped_gp import WarpedGP
|
from warped_gp import WarpedGP
|
||||||
from bayesian_gplvm import BayesianGPLVM
|
from bayesian_gplvm import BayesianGPLVM
|
||||||
from mrd import MRD
|
from mrd import MRD
|
||||||
|
from gradient_checker import GradientChecker
|
||||||
|
|
|
||||||
|
|
@ -37,10 +37,29 @@ class GradientChecker(Model):
|
||||||
Names to print, when performing gradcheck. If a list was passed to x0
|
Names to print, when performing gradcheck. If a list was passed to x0
|
||||||
a list of names with the same length is expected.
|
a list of names with the same length is expected.
|
||||||
:param args: Arguments passed as f(x, *args, **kwargs) and df(x, *args, **kwargs)
|
:param args: Arguments passed as f(x, *args, **kwargs) and df(x, *args, **kwargs)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
---------
|
||||||
|
|
||||||
|
Sinusoid:
|
||||||
|
|
||||||
|
X = numpy.random.rand(N, Q)
|
||||||
|
f = lambda x: numpy.sin(x)
|
||||||
|
df = lambda x: numpy.cos(x)
|
||||||
|
grad = gc.GradientChecker(f,df,X,'x')
|
||||||
|
|
||||||
|
Using GPy:
|
||||||
|
|
||||||
|
N, M, Q = 10, 5, 3
|
||||||
|
X, Z = numpy.random.randn(N,Q), numpy.random.randn(M,Q)
|
||||||
|
kern = GPy.kern.linear(Q, ARD=True) + GPy.kern.rbf(Q, ARD=True)
|
||||||
|
import GPy.models.gradient_checker as gc
|
||||||
|
grad = gc.GradientChecker(kern.K,
|
||||||
|
lambda x: 2*kern.dK_dX(numpy.ones((1,1)), x),
|
||||||
|
x0 = X.copy(),
|
||||||
|
names='X')
|
||||||
"""
|
"""
|
||||||
Model.__init__(self)
|
Model.__init__(self)
|
||||||
self.f = f
|
|
||||||
self.df = df
|
|
||||||
if isinstance(x0, (list, tuple)) and names is None:
|
if isinstance(x0, (list, tuple)) and names is None:
|
||||||
self.shapes = [get_shape(xi) for xi in x0]
|
self.shapes = [get_shape(xi) for xi in x0]
|
||||||
self.names = ['X{i}'.format(i=i) for i in range(len(x0))]
|
self.names = ['X{i}'.format(i=i) for i in range(len(x0))]
|
||||||
|
|
@ -60,17 +79,19 @@ class GradientChecker(Model):
|
||||||
self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
self._param_names.extend(map(lambda nameshape: ('_'.join(nameshape)).strip('_'), itertools.izip(itertools.repeat(name), itertools.imap(lambda t: '_'.join(map(str, t)), itertools.product(*map(lambda xi: range(xi), shape))))))
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
self.f = f
|
||||||
|
self.df = df
|
||||||
|
|
||||||
def _get_x(self):
|
def _get_x(self):
|
||||||
if len(self.names) > 1:
|
if len(self.names) > 1:
|
||||||
return [self.__getattribute__(name) for name in self.names]
|
return [self.__getattribute__(name) for name in self.names] + list(self.args)
|
||||||
return self.__getattribute__(self.names[0])
|
return [self.__getattribute__(self.names[0])] + list(self.args)
|
||||||
|
|
||||||
def log_likelihood(self):
|
def log_likelihood(self):
|
||||||
return numpy.atleast_1d(self.f(self._get_x(), *self.args, **self.kwargs))
|
return float(numpy.sum(self.f(*self._get_x(), **self.kwargs)))
|
||||||
|
|
||||||
def _log_likelihood_gradients(self):
|
def _log_likelihood_gradients(self):
|
||||||
return numpy.atleast_1d(self.df(self._get_x(), *self.args, **self.kwargs))
|
return numpy.atleast_1d(self.df(*self._get_x(), **self.kwargs)).flatten()
|
||||||
|
|
||||||
|
|
||||||
def _get_params(self):
|
def _get_params(self):
|
||||||
|
|
@ -86,4 +107,3 @@ class GradientChecker(Model):
|
||||||
|
|
||||||
def _get_param_names(self):
|
def _get_param_names(self):
|
||||||
return self._param_names
|
return self._param_names
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue