[kernel] plot_ard added (some other fixes as well)

This commit is contained in:
Max Zwiessele 2015-10-09 16:07:57 +01:00
parent e3617942d4
commit d2d8a62d2d
14 changed files with 371 additions and 337 deletions

View file

@ -1,21 +1,21 @@
#===============================================================================
# Copyright (c) 2015, Max Zwiessele
# All rights reserved.
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
#
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
@ -32,7 +32,9 @@ import GPy, os
from nose import SkipTest
from ..util.config import config
from ..plotting import change_plotting_library
change_plotting_library('matplotlib')
if config.get('plotting', 'library') != 'matplotlib':
raise SkipTest("Matplotlib not installed, not testing plots")
@ -64,7 +66,7 @@ def _image_directories():
def _sequenceEqual(a, b):
assert len(a) == len(b), "Sequences not same length"
for i, [x, y], in enumerate(zip(a, b)):
assert x == y, "element not matching {}".format(i)
assert x == y, "element not matching {}".format(i)
def _notFound(path):
raise IOError('File {} not in baseline')
@ -89,7 +91,17 @@ def _image_comparison(baseline_images, extensions=['pdf','svg','ong'], tol=11):
raise ImageComparisonFailure("Error between {} and {} is {:.5f}, which is bigger then the tolerance of {:.5f}".format(actual, expected, err['rms'], tol))
yield do_test
plt.close('all')
def test_kernel():
np.random.seed(1239847)
k = GPy.kern.RBF(5, ARD=True) + GPy.kern.Linear(3, active_dims=[0,2,4], ARD=True) + GPy.kern.Bias(2)
k.randomize()
k.plot_ARD(legend=True)
for do_test in _image_comparison(
baseline_images=['kern_{}'.format(sub) for sub in ["ARD",]],
extensions=extensions):
yield (do_test, )
def test_plot():
np.random.seed(11111)
X = np.random.uniform(-2, 2, (40, 1))
@ -162,7 +174,7 @@ def test_classification():
for do_test in _image_comparison(baseline_images=['gp_class_{}'.format(sub) for sub in ["likelihood", "raw", 'raw_link']], extensions=extensions):
yield (do_test, )
def test_sparse_classification():
np.random.seed(11111)
X = np.random.uniform(-2, 2, (40, 1))
@ -218,7 +230,7 @@ def test_bayesian_gplvm():
m.plot_steepest_gradient_map(resolution=7)
for do_test in _image_comparison(baseline_images=['bayesian_gplvm_{}'.format(sub) for sub in ["inducing", "inducing_3d", "latent_3d", "magnification", 'gradient']], extensions=extensions):
yield (do_test, )
if __name__ == '__main__':
import nose
nose.main()