mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-25 12:56:22 +02:00
[plotting] tests now compare the arrays of the figure, instead of the platform dependend png images
This commit is contained in:
parent
2f901a90e2
commit
108ae55fbc
78 changed files with 97 additions and 30 deletions
|
|
@ -72,7 +72,7 @@ try:
|
|||
except ImportError:
|
||||
raise SkipTest("Matplotlib not installed, not testing plots")
|
||||
|
||||
extensions = ['png']
|
||||
extensions = ['npz']
|
||||
|
||||
def _image_directories():
|
||||
"""
|
||||
|
|
@ -93,39 +93,104 @@ baseline_dir, result_dir = _image_directories()
|
|||
if not os.path.exists(baseline_dir):
|
||||
raise SkipTest("Not installed from source, baseline not available. Install from source to test plotting")
|
||||
|
||||
def _sequenceEqual(a, b):
|
||||
assert len(a) == len(b), "Sequences not same length"
|
||||
for i, [x, y], in enumerate(zip(a, b)):
|
||||
assert x == y, "element not matching {}".format(i)
|
||||
def _image_comparison(baseline_images, extensions=['pdf','svg','png'], tol=11, decimal=6):
|
||||
|
||||
def _notFound(path):
|
||||
raise IOError('File {} not in baseline')
|
||||
|
||||
def _image_comparison(baseline_images, extensions=['pdf','svg','png'], tol=11):
|
||||
for num, base in zip(plt.get_fignums(), baseline_images):
|
||||
for ext in extensions:
|
||||
fig = plt.figure(num)
|
||||
fig.canvas.draw()
|
||||
#fig.axes[0].set_axis_off()
|
||||
#fig.set_frameon(False)
|
||||
fig.canvas.draw()
|
||||
fig.savefig(os.path.join(result_dir, "{}.{}".format(base, ext)),
|
||||
transparent=True,
|
||||
edgecolor='none',
|
||||
facecolor='none',
|
||||
#bbox='tight'
|
||||
)
|
||||
if ext in ['npz']:
|
||||
figdict = flatten_axis(fig)
|
||||
np.savez_compressed(os.path.join(result_dir, "{}.{}".format(base, ext)), **figdict)
|
||||
fig.savefig(os.path.join(result_dir, "{}.{}".format(base, 'png')),
|
||||
transparent=True,
|
||||
edgecolor='none',
|
||||
facecolor='none',
|
||||
#bbox='tight'
|
||||
)
|
||||
else:
|
||||
fig.savefig(os.path.join(result_dir, "{}.{}".format(base, ext)),
|
||||
transparent=True,
|
||||
edgecolor='none',
|
||||
facecolor='none',
|
||||
#bbox='tight'
|
||||
)
|
||||
for num, base in zip(plt.get_fignums(), baseline_images):
|
||||
for ext in extensions:
|
||||
#plt.close(num)
|
||||
actual = os.path.join(result_dir, "{}.{}".format(base, ext))
|
||||
expected = os.path.join(baseline_dir, "{}.{}".format(base, ext))
|
||||
def do_test():
|
||||
err = compare_images(expected, actual, tol, in_decorator=True)
|
||||
if err:
|
||||
raise SkipTest("Error between {} and {} is {:.5f}, which is bigger then the tolerance of {:.5f}".format(actual, expected, err['rms'], tol))
|
||||
if ext == 'npz':
|
||||
def do_test():
|
||||
if not os.path.exists(expected):
|
||||
import shutil
|
||||
shutil.copy2(actual, expected)
|
||||
#shutil.copy2(os.path.join(result_dir, "{}.{}".format(base, 'png')), os.path.join(baseline_dir, "{}.{}".format(base, 'png')))
|
||||
raise IOError("Baseline file {} not found, copying result {}".format(expected, actual))
|
||||
else:
|
||||
exp_dict = dict(np.load(expected).items())
|
||||
act_dict = dict(np.load(actual).items())
|
||||
assert(len(exp_dict)==len(act_dict))
|
||||
for name in exp_dict:
|
||||
np.testing.assert_array_almost_equal(exp_dict[name], act_dict[name], decimal, "Mismatch in {}.{}".format(base, name))
|
||||
else:
|
||||
def do_test():
|
||||
err = compare_images(expected, actual, tol, in_decorator=True)
|
||||
if err:
|
||||
raise SkipTest("Error between {} and {} is {:.5f}, which is bigger then the tolerance of {:.5f}".format(actual, expected, err['rms'], tol))
|
||||
yield do_test
|
||||
plt.close('all')
|
||||
|
||||
def flatten_axis(ax, prevname=''):
|
||||
import inspect
|
||||
members = inspect.getmembers(ax)
|
||||
|
||||
arrays = {}
|
||||
|
||||
def _flatten(l, pre):
|
||||
arr = {}
|
||||
if isinstance(l, np.ndarray):
|
||||
if l.size:
|
||||
arr[pre] = np.asarray(l)
|
||||
elif isinstance(l, dict):
|
||||
for _n in l:
|
||||
_tmp = _flatten(l, pre+"."+_n+".")
|
||||
for _nt in _tmp.keys():
|
||||
arrays[_nt] = _tmp[_nt]
|
||||
elif isinstance(l, list) and len(l)>0:
|
||||
for i in range(len(l)):
|
||||
_tmp = _flatten(l[i], pre+"[{}]".format(i))
|
||||
for _n in _tmp:
|
||||
arr["{}".format(_n)] = _tmp[_n]
|
||||
else:
|
||||
return flatten_axis(l, pre+'.')
|
||||
return arr
|
||||
|
||||
|
||||
for name, l in members:
|
||||
if isinstance(l, np.ndarray):
|
||||
arrays[prevname+name] = np.asarray(l)
|
||||
elif isinstance(l, list) and len(l)>0:
|
||||
for i in range(len(l)):
|
||||
_tmp = _flatten(l[i], prevname+name+"[{}]".format(i))
|
||||
for _n in _tmp:
|
||||
arrays["{}".format(_n)] = _tmp[_n]
|
||||
|
||||
return arrays
|
||||
|
||||
def _a(x,y,decimal):
|
||||
np.testing.assert_array_almost_equal(x, y, decimal)
|
||||
|
||||
def compare_axis_dicts(x, y, decimal=6):
|
||||
try:
|
||||
assert(len(x)==len(y))
|
||||
for name in x:
|
||||
_a(x[name], y[name], decimal)
|
||||
except AssertionError as e:
|
||||
raise SkipTest(e.message)
|
||||
|
||||
def test_figure():
|
||||
np.random.seed(1239847)
|
||||
from GPy.plotting import plotting_library as pl
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue