[magnification] plot_magnification expanded

This commit is contained in:
Max Zwiessele 2015-09-03 09:33:07 +01:00
parent ca60ad3195
commit 839e3dc6f0
8 changed files with 64 additions and 60 deletions

View file

@ -371,7 +371,7 @@ class GP(Model):
var_jac = compute_cov_inner(self.posterior.woodbury_inv) var_jac = compute_cov_inner(self.posterior.woodbury_inv)
return mean_jac, var_jac return mean_jac, var_jac
def predict_wishard_embedding(self, Xnew, kern=None): def predict_wishard_embedding(self, Xnew, kern=None, mean=True, covariance=True):
""" """
Predict the wishard embedding G of the GP. This is the density of the Predict the wishard embedding G of the GP. This is the density of the
input of the GP defined by the probabilistic function mapping f. input of the GP defined by the probabilistic function mapping f.
@ -391,13 +391,16 @@ class GP(Model):
mumuT = np.einsum('iqd,ipd->iqp', mu_jac, mu_jac) mumuT = np.einsum('iqd,ipd->iqp', mu_jac, mu_jac)
if var_jac.ndim == 3: if var_jac.ndim == 3:
Sigma = np.einsum('iqd,ipd->iqp', var_jac, var_jac) Sigma = np.einsum('iqd,ipd->iqp', var_jac, var_jac)
G = mumuT + Sigma
else: else:
Sigma = np.einsum('iq,ip->iqp', var_jac, var_jac) Sigma = self.output_dim*np.einsum('iq,ip->iqp', var_jac, var_jac)
G = mumuT + self.output_dim*Sigma G = 0.
if mean:
G += mumuT
if covariance:
G += Sigma
return G return G
def predict_magnification(self, Xnew, kern=None): def predict_magnification(self, Xnew, kern=None, mean=True, covariance=True):
""" """
Predict the magnification factor as Predict the magnification factor as
@ -405,7 +408,7 @@ class GP(Model):
for each point N in Xnew for each point N in Xnew
""" """
G = self.predict_wishard_embedding(Xnew, kern) G = self.predict_wishard_embedding(Xnew, kern, mean, covariance)
from ..util.linalg import jitchol from ..util.linalg import jitchol
return np.array([np.sqrt(np.exp(2*np.sum(np.log(np.diag(jitchol(G[n, :, :])))))) for n in range(Xnew.shape[0])]) return np.array([np.sqrt(np.exp(2*np.sum(np.log(np.diag(jitchol(G[n, :, :])))))) for n in range(Xnew.shape[0])])
#return np.array([np.sqrt(np.linalg.det(G[n, :, :])) for n in range(Xnew.shape[0])]) #return np.array([np.sqrt(np.linalg.det(G[n, :, :])) for n in range(Xnew.shape[0])])
@ -569,7 +572,7 @@ class GP(Model):
resolution=50, ax=None, marker='o', s=40, resolution=50, ax=None, marker='o', s=40,
fignum=None, legend=True, fignum=None, legend=True,
plot_limits=None, plot_limits=None,
aspect='auto', updates=False, **kwargs): aspect='auto', updates=False, plot_inducing=True, kern=None, **kwargs):
import sys import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported." assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
@ -577,7 +580,7 @@ class GP(Model):
return dim_reduction_plots.plot_magnification(self, labels, which_indices, return dim_reduction_plots.plot_magnification(self, labels, which_indices,
resolution, ax, marker, s, resolution, ax, marker, s,
fignum, False, legend, fignum, plot_inducing, legend,
plot_limits, aspect, updates, **kwargs) plot_limits, aspect, updates, **kwargs)

View file

@ -181,18 +181,3 @@ class SparseGP(GP):
var[i] = np.diag(var_)+p0-t2 var[i] = np.diag(var_)+p0-t2
return mu, var return mu, var
def plot_magnification(self, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, legend=True,
plot_limits=None,
aspect='auto', updates=False, plot_inducing=True, **kwargs):
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import dim_reduction_plots
return dim_reduction_plots.plot_magnification(self, labels, which_indices,
resolution, ax, marker, s,
fignum, plot_inducing, legend,
plot_limits, aspect, updates, **kwargs)

View file

@ -14,7 +14,7 @@ class Add(CombinationKernel):
This kernel will take over the active dims of it's subkernels passed in. This kernel will take over the active dims of it's subkernels passed in.
""" """
def __init__(self, subkerns, name='add'): def __init__(self, subkerns, name='sum'):
for i, kern in enumerate(subkerns[:]): for i, kern in enumerate(subkerns[:]):
if isinstance(kern, Add): if isinstance(kern, Add):
del subkerns[i] del subkerns[i]
@ -72,6 +72,16 @@ class Add(CombinationKernel):
[target.__iadd__(p.gradients_X_diag(dL_dKdiag, X)) for p in self.parts] [target.__iadd__(p.gradients_X_diag(dL_dKdiag, X)) for p in self.parts]
return target return target
def gradients_XX(self, dL_dK, X, X2):
target = 0.
[target.__iadd__(p.gradients_XX(dL_dK, X, X2)) for p in self.parts]
return target
def gradients_XX_diag(self, dL_dKdiag, X):
target = np.zeros(X.shape)
[target.__iadd__(p.gradients_XX_diag(dL_dKdiag, X)) for p in self.parts]
return target
@Cache_this(limit=2, force_kwargs=['which_parts']) @Cache_this(limit=2, force_kwargs=['which_parts'])
def psi0(self, Z, variational_posterior): def psi0(self, Z, variational_posterior):
return reduce(np.add, (p.psi0(Z, variational_posterior) for p in self.parts)) return reduce(np.add, (p.psi0(Z, variational_posterior) for p in self.parts))

View file

@ -102,7 +102,7 @@ class Kern(Parameterized):
raise NotImplementedError raise NotImplementedError
def gradients_XX(self, dL_dK, X, X2): def gradients_XX(self, dL_dK, X, X2):
raise(NotImplementedError, "This is the second derivative of K wrt X and X2, and not implemented for this kernel") raise(NotImplementedError, "This is the second derivative of K wrt X and X2, and not implemented for this kernel")
def gradients_XX_diag(self, dL_dK, X, X2): def gradients_XX_diag(self, dL_dKdiag, X):
raise(NotImplementedError, "This is the diagonal of the second derivative of K wrt X and X2, and not implemented for this kernel") raise(NotImplementedError, "This is the diagonal of the second derivative of K wrt X and X2, and not implemented for this kernel")
def gradients_X_diag(self, dL_dKdiag, X): def gradients_X_diag(self, dL_dKdiag, X):
raise NotImplementedError raise NotImplementedError

View file

@ -24,6 +24,11 @@ class Static(Kern):
def gradients_X_diag(self, dL_dKdiag, X): def gradients_X_diag(self, dL_dKdiag, X):
return np.zeros(X.shape) return np.zeros(X.shape)
def gradients_XX(self, dL_dK, X, X2):
return np.zeros((X.shape[0], X2.shape[0], X.shape[1]), dtype=np.float64)
def gradients_XX_diag(self, dL_dKdiag, X):
return np.zeros(X.shape)
def gradients_Z_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior): def gradients_Z_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
return np.zeros(Z.shape) return np.zeros(Z.shape)

View file

@ -137,20 +137,6 @@ class BayesianGPLVM(SparseGP_MPI):
fignum, plot_inducing, legend, fignum, plot_inducing, legend,
plot_limits, aspect, updates, predict_kwargs, imshow_kwargs) plot_limits, aspect, updates, predict_kwargs, imshow_kwargs)
def plot_magnification(self, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, legend=True,
plot_limits=None,
aspect='auto', updates=False, **kwargs):
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import dim_reduction_plots
return dim_reduction_plots.plot_magnification(self, labels, which_indices,
resolution, ax, marker, s,
fignum, False, legend,
plot_limits, aspect, updates, **kwargs)
def do_test_latents(self, Y): def do_test_latents(self, Y):
""" """
Compute the latent representation for a set of new points Y Compute the latent representation for a set of new points Y

View file

@ -9,7 +9,8 @@ import itertools
try: try:
import Tango import Tango
from matplotlib.cm import get_cmap from matplotlib.cm import get_cmap
import pylab as pb from matplotlib import pyplot as pb
from matplotlib import cm
except: except:
pass pass
@ -137,7 +138,7 @@ def plot_latent(model, labels=None, which_indices=None,
view = ImshowController(ax, plot_function, view = ImshowController(ax, plot_function,
(xmin, ymin, xmax, ymax), (xmin, ymin, xmax, ymax),
resolution, aspect=aspect, interpolation='bilinear', resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.binary, **imshow_kwargs) cmap=cm.binary, **imshow_kwargs)
# make sure labels are in order of input: # make sure labels are in order of input:
labels = np.asarray(labels) labels = np.asarray(labels)
@ -192,18 +193,18 @@ def plot_latent(model, labels=None, which_indices=None,
if updates: if updates:
try: try:
ax.figure.canvas.show() fig.canvas.show()
except Exception as e: except Exception as e:
print("Could not invoke show: {}".format(e)) print("Could not invoke show: {}".format(e))
raw_input('Enter to continue') #raw_input('Enter to continue')
view.deactivate() return view
return ax return ax
def plot_magnification(model, labels=None, which_indices=None, def plot_magnification(model, labels=None, which_indices=None,
resolution=60, ax=None, marker='o', s=40, resolution=60, ax=None, marker='o', s=40,
fignum=None, plot_inducing=False, legend=True, fignum=None, plot_inducing=False, legend=True,
plot_limits=None, plot_limits=None,
aspect='auto', updates=False): aspect='auto', updates=False, mean=True, covariance=True, kern=None):
""" """
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc) :param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance :param resolution: the resolution of the grid on which to evaluate the predictive variance
@ -211,6 +212,8 @@ def plot_magnification(model, labels=None, which_indices=None,
if ax is None: if ax is None:
fig = pb.figure(num=fignum) fig = pb.figure(num=fignum)
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
else:
fig = ax.figure
Tango.reset() Tango.reset()
if labels is None: if labels is None:
@ -295,13 +298,13 @@ def plot_magnification(model, labels=None, which_indices=None,
def plot_function(x): def plot_function(x):
Xtest_full = np.zeros((x.shape[0], X.shape[1])) Xtest_full = np.zeros((x.shape[0], X.shape[1]))
Xtest_full[:, [input_1, input_2]] = x Xtest_full[:, [input_1, input_2]] = x
mf = model.predict_magnification(Xtest_full) mf = model.predict_magnification(Xtest_full, kern=kern, mean=mean, covariance=covariance)
return mf return mf
view = ImshowController(ax, plot_function, view = ImshowController(ax, plot_function,
(xmin, ymin, xmax, ymax), (xmin, ymin, xmax, ymax),
resolution, aspect=aspect, interpolation='bilinear', resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.gray) cmap=cm.gray)
# make sure labels are in order of input: # make sure labels are in order of input:
ulabels = [] ulabels = []
@ -317,7 +320,7 @@ def plot_magnification(model, labels=None, which_indices=None,
elif type(ul) is np.int64: elif type(ul) is np.int64:
this_label = 'class %i' % ul this_label = 'class %i' % ul
else: else:
this_label = 'class %i' % i this_label = unicode(ul)
m = marker.next() m = marker.next()
index = np.nonzero(labels == ul)[0] index = np.nonzero(labels == ul)[0]
@ -327,7 +330,7 @@ def plot_magnification(model, labels=None, which_indices=None,
else: else:
x = X[index, input_1] x = X[index, input_1]
y = X[index, input_2] y = X[index, input_2]
ax.scatter(x, y, marker=m, s=s, color=Tango.nextMedium(), label=this_label) ax.scatter(x, y, marker=m, s=s, c=Tango.nextMedium(), label=this_label, linewidth=.2, edgecolor='k', alpha=.9)
ax.set_xlabel('latent dimension %i' % input_1) ax.set_xlabel('latent dimension %i' % input_1)
ax.set_ylabel('latent dimension %i' % input_2) ax.set_ylabel('latent dimension %i' % input_2)
@ -337,18 +340,27 @@ def plot_magnification(model, labels=None, which_indices=None,
ax.set_xlim((xmin, xmax)) ax.set_xlim((xmin, xmax))
ax.set_ylim((ymin, ymax)) ax.set_ylim((ymin, ymax))
ax.grid(b=False) # remove the grid if present, it doesn't look good
ax.set_aspect('auto') # set a nice aspect ratio
if plot_inducing: if plot_inducing and hasattr(model, 'Z'):
Z = model.Z Z = model.Z
ax.scatter(Z[:, input_1], Z[:, input_2], c='w', s=18, marker="^", edgecolor='k', linewidth=.3, alpha=.7) ax.scatter(Z[:, input_1], Z[:, input_2], c='w', s=18, marker="^", edgecolor='k', linewidth=.3, alpha=.7)
if updates: try:
fig.canvas.show() fig.canvas.draw()
raw_input('Enter to continue') fig.tight_layout()
fig.canvas.draw()
except Exception as e:
print("Could not invoke tight layout: {}".format(e))
pass
pb.title('Magnification Factor') if updates:
try:
fig.canvas.draw()
fig.canvas.show()
except Exception as e:
print("Could not invoke show: {}".format(e))
#raw_input('Enter to continue')
return view
return ax return ax

View file

@ -9,6 +9,9 @@ class AxisEventController(object):
def __init__(self, ax): def __init__(self, ax):
self.ax = ax self.ax = ax
self.activate() self.activate()
def __del__(self):
self.deactivate()
return self
def deactivate(self): def deactivate(self):
for cb_class in self.ax.callbacks.callbacks.values(): for cb_class in self.ax.callbacks.callbacks.values():
for cb_num in cb_class.keys(): for cb_num in cb_class.keys():