GPy/test6.py
2014-06-04 11:21:25 +01:00

95 lines
3.1 KiB
Python

from matplotlib import pyplot as plt
import GPy
import numpy as np
class lvm_visualise:
def __init__(self, visualise, ax):
self.cid = ax.figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = ax.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.visualise = visualise
self.ax = ax
# This is vectorDisplay code
self.called = False
self.move_on = False
def on_click(self, event):
print 'click', event.xdata, event.ydata
if event.inaxes!=self.ax: return
self.move_on = not self.move_on
print
if self.called:
self.xs.append(event.xdata)
self.ys.append(event.ydata)
self.line.set_data(self.xs, self.ys)
self.line.figure.canvas.draw()
else:
self.xs = [event.xdata]
self.ys = [event.ydata]
self.line, = ax.plot(event.xdata, event.ydata)
self.called = True
def on_move(self, event):
if event.inaxes!=self.ax: return
if self.called and self.move_on:
# This is vectorModify code
#print 'move', event.xdata, event.ydata
latent_values = np.array((event.xdata, event.ydata))
self.visualise.modify(latent_values)
#print 'y', y
class data_visualiser:
def __init__(self, model):
self.model = model
def modify(self, latent_values):
raise NotImplementedError, "this needs to be implemented to use the data_visualiser class"
class vector_visualise(data_visualiser):
def __init__(self, model):
data_visualiser.__init__(self, model)
self.fig_display = plt.figure()
self.model = model
self.y = model.predict(np.zeros((1, model.input_dim)))[0]
self.handle = plt.plot(np.arange(0, model.output_dim)[:, None], self.y.T)[0]
def modify(self, latent_values):
y = self.model.predict(latent_values)[0]
xdata, ydata = self.handle.get_data()
self.handle.set_data(xdata, y.T)
plt.show()
class image_visualise(data_visualiser):
def __init__(self, model, dimensions=(16,16), transpose=False, invert=False):
data_visualiser.__init__(self, model)
self.fig_display = plt.figure()
self.model = model
self.imvals = self.get_array(model, np.zeros((1, model.input_dim)))
self.handle = plt.imshow(self.imvals)
self.transpose = transpose
self.invert = invert
def modify(self, latent_values):
xdata, ydata = self.handle.get_data()
self.imvals = self.get_array(model, latent_values)
self.handle.set_array(self.imvals)
plt.show()
def get_array(self, latent_values):
self.y = self.model.predict(latent_values)[0]
imvals = np.reshape(self.y, dimensions)
if self.transpose:
imvals = imvals.T
if self.invert:
imvals = -imvals
model = GPy.examples.dimensionality_reduction.oil_100()
visualise = vector_visualise(model)
fig = plt.figure()
ax = fig.add_subplot(111)
plt.ylim((-1, 1))
plt.xlim((-1, 1))
ax.set_title('latent space')
linebuilder = lvm_visualise(visualise, ax)
plt.show()