some minor improvements in visualize

This commit is contained in:
James Hensman 2013-04-10 20:02:22 +01:00
parent 87304a0778
commit 48b0ac6399
4 changed files with 19 additions and 16 deletions

View file

@ -60,7 +60,7 @@ def GPLVM_oil_100(optimize=True,M=15):
m.plot_latent(labels=m.data_labels) m.plot_latent(labels=m.data_labels)
return m return m
def BGPLVM_oil(optimize=True,N=100,Q=10,M=15): def BGPLVM_oil(optimize=True,N=100,Q=10,M=15,max_f_eval=300):
data = GPy.util.datasets.oil() data = GPy.util.datasets.oil()
# create simple GP model # create simple GP model
@ -72,10 +72,10 @@ def BGPLVM_oil(optimize=True,N=100,Q=10,M=15):
if optimize: if optimize:
m.constrain_fixed('noise',0.05) m.constrain_fixed('noise',0.05)
m.ensure_default_constraints() m.ensure_default_constraints()
m.optimize('scg',messages=1) m.optimize('scg',messages=1,max_f_eval=max(80,max_f_eval))
m.unconstrain('noise') m.unconstrain('noise')
m.constrain_positive('noise') m.constrain_positive('noise')
m.optimize('scg',messages=1) m.optimize('scg',messages=1,max_f_eval=max(0,max_f_eval-80))
else: else:
m.ensure_default_constraints() m.ensure_default_constraints()

View file

@ -173,7 +173,7 @@ class rbf(kernpart):
"""Think N,M,M,Q """ """Think N,M,M,Q """
self._psi_computations(Z,mu,S) self._psi_computations(Z,mu,S)
tmp = self._psi2[:,:,:,None]/self.lengthscale2/self._psi2_denom tmp = self._psi2[:,:,:,None]/self.lengthscale2/self._psi2_denom
target_mu += (dL_dpsi2[:,:,:,None]*-tmp*2.*self._psi2_mudist).sum(1).sum(1) target_mu += -2.*(dL_dpsi2[:,:,:,None]*tmp*self._psi2_mudist).sum(1).sum(1)
target_S += (dL_dpsi2[:,:,:,None]*tmp*(2.*self._psi2_mudist_sq-1)).sum(1).sum(1) target_S += (dL_dpsi2[:,:,:,None]*tmp*(2.*self._psi2_mudist_sq-1)).sum(1).sum(1)
@ -207,7 +207,6 @@ class rbf(kernpart):
if not (np.all(Z==self._Z) and np.all(mu==self._mu) and np.all(S==self._S)): if not (np.all(Z==self._Z) and np.all(mu==self._mu) and np.all(S==self._S)):
#something's changed. recompute EVERYTHING #something's changed. recompute EVERYTHING
#TODO: make more efficient for large Q (using NDL's dot product trick)
#psi1 #psi1
self._psi1_denom = S[:,None,:]/self.lengthscale2 + 1. self._psi1_denom = S[:,None,:]/self.lengthscale2 + 1.
self._psi1_dist = Z[None,:,:]-mu[:,None,:] self._psi1_dist = Z[None,:,:]-mu[:,None,:]

View file

@ -95,3 +95,4 @@ class Bayesian_GPLVM(sparse_GP, GPLVM):
input_1, input_2 = which_indices input_1, input_2 = which_indices
ax = GPLVM.plot_latent(self, which_indices=[input_1, input_2],*args, **kwargs) ax = GPLVM.plot_latent(self, which_indices=[input_1, input_2],*args, **kwargs)
ax.plot(self.Z[:, input_1], self.Z[:, input_2], '^w') ax.plot(self.Z[:, input_1], self.Z[:, input_2], '^w')
return ax

View file

@ -4,7 +4,7 @@ import GPy
import numpy as np import numpy as np
class lvm: class lvm:
def __init__(self, model, data_visualize, latent_axis): def __init__(self, model, data_visualize, latent_axis, latent_index=[0,1], latent_dim=2):
self.cid = latent_axis.figure.canvas.mpl_connect('button_press_event', self.on_click) self.cid = latent_axis.figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = latent_axis.figure.canvas.mpl_connect('motion_notify_event', self.on_move) self.cid = latent_axis.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.data_visualize = data_visualize self.data_visualize = data_visualize
@ -12,6 +12,8 @@ class lvm:
self.latent_axis = latent_axis self.latent_axis = latent_axis
self.called = False self.called = False
self.move_on = False self.move_on = False
self.latent_index = latent_index
self.latent_dim = latent_dim
def on_click(self, event): def on_click(self, event):
#print 'click', event.xdata, event.ydata #print 'click', event.xdata, event.ydata
@ -32,7 +34,8 @@ class lvm:
if self.called and self.move_on: if self.called and self.move_on:
# Call modify code on move # Call modify code on move
#print 'move', event.xdata, event.ydata #print 'move', event.xdata, event.ydata
latent_values = np.array((event.xdata, event.ydata)) latent_values = np.zeros((1,self.latent_dim))
latent_values[0,self.latent_index] = np.array([event.xdata, event.ydata])
y = self.model.predict(latent_values)[0] y = self.model.predict(latent_values)[0]
self.data_visualize.modify(y) self.data_visualize.modify(y)
#print 'y', y #print 'y', y
@ -57,7 +60,7 @@ class vector_show(data_show):
def __init__(self, vals, axis=None): def __init__(self, vals, axis=None):
data_show.__init__(self, vals, axis) data_show.__init__(self, vals, axis)
self.vals = vals.T self.vals = vals.T
self.handle = plt.plot(np.arange(0, len(vals))[:, None], self.vals)[0] self.handle = self.axis.plot(np.arange(0, len(vals))[:, None], self.vals)[0]
def modify(self, vals): def modify(self, vals):
xdata, ydata = self.handle.get_data() xdata, ydata = self.handle.get_data()