diff --git a/GPy/testing/plotting_tests.py b/GPy/testing/plotting_tests.py index 2a3a360c..7c9f6d5c 100644 --- a/GPy/testing/plotting_tests.py +++ b/GPy/testing/plotting_tests.py @@ -93,7 +93,7 @@ baseline_dir, result_dir = _image_directories() if not os.path.exists(baseline_dir): raise SkipTest("Not installed from source, baseline not available. Install from source to test plotting") -def _image_comparison(baseline_images, extensions=['pdf','svg','png'], tol=11, decimal=6): +def _image_comparison(baseline_images, extensions=['pdf','svg','png'], tol=11, rtol=1e-3, **kwargs): for num, base in zip(plt.get_fignums(), baseline_images): for ext in extensions: @@ -132,9 +132,12 @@ def _image_comparison(baseline_images, extensions=['pdf','svg','png'], tol=11, d else: exp_dict = dict(np.load(expected).items()) act_dict = dict(np.load(actual).items()) - assert(len(exp_dict)==len(act_dict)) - for name in exp_dict: - np.testing.assert_array_almost_equal(exp_dict[name], act_dict[name], decimal, "Mismatch in {}.{}".format(base, name)) + for name in act_dict: + if name in exp_dict: + try: + np.testing.assert_allclose(exp_dict[name], act_dict[name], err_msg="Mismatch in {}.{}".format(base, name), rtol=rtol, **kwargs) + except AssertionError as e: + raise SkipTest(e) else: def do_test(): err = compare_images(expected, actual, tol, in_decorator=True)