[magnification] plot_magnification expanded

This commit is contained in:
Max Zwiessele 2015-09-03 09:33:07 +01:00
parent ca60ad3195
commit 839e3dc6f0
8 changed files with 64 additions and 60 deletions

View file

@ -9,7 +9,8 @@ import itertools
try:
import Tango
from matplotlib.cm import get_cmap
import pylab as pb
from matplotlib import pyplot as pb
from matplotlib import cm
except:
pass
@ -137,7 +138,7 @@ def plot_latent(model, labels=None, which_indices=None,
view = ImshowController(ax, plot_function,
(xmin, ymin, xmax, ymax),
resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.binary, **imshow_kwargs)
cmap=cm.binary, **imshow_kwargs)
# make sure labels are in order of input:
labels = np.asarray(labels)
@ -192,18 +193,18 @@ def plot_latent(model, labels=None, which_indices=None,
if updates:
try:
ax.figure.canvas.show()
fig.canvas.show()
except Exception as e:
print("Could not invoke show: {}".format(e))
raw_input('Enter to continue')
view.deactivate()
#raw_input('Enter to continue')
return view
return ax
def plot_magnification(model, labels=None, which_indices=None,
resolution=60, ax=None, marker='o', s=40,
fignum=None, plot_inducing=False, legend=True,
plot_limits=None,
aspect='auto', updates=False):
aspect='auto', updates=False, mean=True, covariance=True, kern=None):
"""
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance
@ -211,6 +212,8 @@ def plot_magnification(model, labels=None, which_indices=None,
if ax is None:
fig = pb.figure(num=fignum)
ax = fig.add_subplot(111)
else:
fig = ax.figure
Tango.reset()
if labels is None:
@ -295,13 +298,13 @@ def plot_magnification(model, labels=None, which_indices=None,
def plot_function(x):
Xtest_full = np.zeros((x.shape[0], X.shape[1]))
Xtest_full[:, [input_1, input_2]] = x
mf = model.predict_magnification(Xtest_full)
mf = model.predict_magnification(Xtest_full, kern=kern, mean=mean, covariance=covariance)
return mf
view = ImshowController(ax, plot_function,
(xmin, ymin, xmax, ymax),
resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.gray)
cmap=cm.gray)
# make sure labels are in order of input:
ulabels = []
@ -317,7 +320,7 @@ def plot_magnification(model, labels=None, which_indices=None,
elif type(ul) is np.int64:
this_label = 'class %i' % ul
else:
this_label = 'class %i' % i
this_label = unicode(ul)
m = marker.next()
index = np.nonzero(labels == ul)[0]
@ -327,7 +330,7 @@ def plot_magnification(model, labels=None, which_indices=None,
else:
x = X[index, input_1]
y = X[index, input_2]
ax.scatter(x, y, marker=m, s=s, color=Tango.nextMedium(), label=this_label)
ax.scatter(x, y, marker=m, s=s, c=Tango.nextMedium(), label=this_label, linewidth=.2, edgecolor='k', alpha=.9)
ax.set_xlabel('latent dimension %i' % input_1)
ax.set_ylabel('latent dimension %i' % input_2)
@ -337,18 +340,27 @@ def plot_magnification(model, labels=None, which_indices=None,
ax.set_xlim((xmin, xmax))
ax.set_ylim((ymin, ymax))
ax.grid(b=False) # remove the grid if present, it doesn't look good
ax.set_aspect('auto') # set a nice aspect ratio
if plot_inducing:
if plot_inducing and hasattr(model, 'Z'):
Z = model.Z
ax.scatter(Z[:, input_1], Z[:, input_2], c='w', s=18, marker="^", edgecolor='k', linewidth=.3, alpha=.7)
if updates:
fig.canvas.show()
raw_input('Enter to continue')
try:
fig.canvas.draw()
fig.tight_layout()
fig.canvas.draw()
except Exception as e:
print("Could not invoke tight layout: {}".format(e))
pass
pb.title('Magnification Factor')
if updates:
try:
fig.canvas.draw()
fig.canvas.show()
except Exception as e:
print("Could not invoke show: {}".format(e))
#raw_input('Enter to continue')
return view
return ax