diff --git a/GPy/util/plot_latent.py b/GPy/util/plot_latent.py index dd21b2ea..81c3d6fc 100644 --- a/GPy/util/plot_latent.py +++ b/GPy/util/plot_latent.py @@ -46,8 +46,9 @@ def plot_latent(model, labels=None, which_indices=None, Xtest_full[:, [input_1, input_2]] = x mu, var, low, up = model.predict(Xtest_full) var = var[:, :1] - return var - view = ImshowController(ax, plot_function, tuple(xmin) + tuple(xmax), + return np.log(var) + view = ImshowController(ax, plot_function, + tuple(model.X.min(0)[:, [input_1, input_2]]) + tuple(model.X.max(0)[:, [input_1, input_2]]), resolution, aspect=aspect, interpolation='bilinear', cmap=pb.cm.binary) @@ -124,7 +125,8 @@ def plot_magnification(model, labels=None, which_indices=None, Xtest_full[:, [input_1, input_2]] = x mf=model.magnification(Xtest_full) return mf - view = ImshowController(ax, plot_function, tuple(xmin) + tuple(xmax), + view = ImshowController(ax, plot_function, + tuple(model.X.min(0)[:, [input_1, input_2]]) + tuple(model.X.max(0)[:, [input_1, input_2]]), resolution, aspect=aspect, interpolation='bilinear', cmap=pb.cm.gray)