mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
Merge changes.
This commit is contained in:
parent
a500c526dd
commit
c6ee7cd98f
2 changed files with 48 additions and 33 deletions
|
|
@ -133,8 +133,8 @@ class GP(model):
|
|||
|
||||
"""
|
||||
Kx = self.kern.K(self.X, _Xnew,which_parts=which_parts)
|
||||
mu = np.dot(np.dot(Kx.T, self.Ki), self.likelihood.Y)
|
||||
KiKx = np.dot(self.Ki, Kx)
|
||||
mu = np.dot(KiKx.T, self.likelihood.Y)
|
||||
if full_cov:
|
||||
Kxx = self.kern.K(_Xnew, which_parts=which_parts)
|
||||
var = Kxx - np.dot(KiKx.T, Kx)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class vector_show(data_show):
|
|||
|
||||
|
||||
class lvm(data_show):
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, latent_index=[0,1]):
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
|
||||
"""Visualize a latent variable model
|
||||
|
||||
:param model: the latent variable model to visualize.
|
||||
|
|
@ -70,7 +70,7 @@ class lvm(data_show):
|
|||
self.data_visualize = data_visualize
|
||||
self.model = model
|
||||
self.latent_axes = latent_axes
|
||||
|
||||
self.sense_axes = sense_axes
|
||||
self.called = False
|
||||
self.move_on = False
|
||||
self.latent_index = latent_index
|
||||
|
|
@ -80,11 +80,13 @@ class lvm(data_show):
|
|||
self.latent_values = vals
|
||||
self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
|
||||
self.modify(vals)
|
||||
self.show_sensitivities()
|
||||
|
||||
def modify(self, vals):
|
||||
"""When latent values are modified update the latent representation and ulso update the output visualization."""
|
||||
|
||||
y = self.model.predict(vals)[0]
|
||||
print y
|
||||
self.data_visualize.modify(y)
|
||||
self.latent_handle.set_data(vals[self.latent_index[0]], vals[self.latent_index[1]])
|
||||
self.axes.figure.canvas.draw()
|
||||
|
|
@ -99,6 +101,7 @@ class lvm(data_show):
|
|||
if event.inaxes!=self.latent_axes: return
|
||||
self.move_on = not self.move_on
|
||||
self.called = True
|
||||
|
||||
def on_move(self, event):
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
if self.called and self.move_on:
|
||||
|
|
@ -107,22 +110,54 @@ class lvm(data_show):
|
|||
self.latent_values[self.latent_index[1]]=event.ydata
|
||||
self.modify(self.latent_values)
|
||||
|
||||
def show_sensitivities(self):
|
||||
# A click in the bar chart axis for selection a dimension.
|
||||
if self.sense_axes != None:
|
||||
self.sense_axes.cla()
|
||||
self.sense_axes.bar(np.arange(self.model.Q),1./self.model.input_sensitivity(),color='b')
|
||||
|
||||
if self.latent_index[1] == self.latent_index[0]:
|
||||
self.sense_axes.bar(np.array(self.latent_index[0]),1./self.model.input_sensitivity()[self.latent_index[0]],color='y')
|
||||
self.sense_axes.bar(np.array(self.latent_index[1]),1./self.model.input_sensitivity()[self.latent_index[1]],color='y')
|
||||
|
||||
else:
|
||||
self.sense_axes.bar(np.array(self.latent_index[0]),1./self.model.input_sensitivity()[self.latent_index[0]],color='g')
|
||||
self.sense_axes.bar(np.array(self.latent_index[1]),1./self.model.input_sensitivity()[self.latent_index[1]],color='r')
|
||||
|
||||
self.sense_axes.figure.canvas.draw()
|
||||
|
||||
|
||||
class lvm_subplots(lvm):
|
||||
"""
|
||||
latent_axes is a np array of dimension np.ceil(Q/2) + 1,
|
||||
one for each pair of the axes, and the last one for the sensitiity bar chart
|
||||
latent_axes is a np array of dimension np.ceil(Q/2),
|
||||
one for each pair of the latent dimensions.
|
||||
"""
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, latent_index=[0,1]):
|
||||
lvm.__init__(self, vals, model,data_visualize,latent_axes,[0,1])
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None):
|
||||
self.nplots = int(np.ceil(model.Q/2.))+1
|
||||
lvm.__init__(self,model,data_visualize,latent_axes,latent_index)
|
||||
self.latent_values = np.zeros(2*np.ceil(self.model.Q/2.)) # possibly an extra dimension on this
|
||||
assert latent_axes.size == self.nplots
|
||||
assert len(latent_axes)==self.nplots
|
||||
if vals==None:
|
||||
vals = model.X[0, :]
|
||||
self.latent_values = vals
|
||||
|
||||
for i, axis in enumerate(latent_axes):
|
||||
if i == self.nplots-1:
|
||||
if self.nplots*2!=model.Q:
|
||||
latent_index = [i*2, i*2]
|
||||
lvm.__init__(self, self.latent_vals, model, data_visualize, axis, sense_axes, latent_index=latent_index)
|
||||
else:
|
||||
latent_index = [i*2, i*2+1]
|
||||
lvm.__init__(self, self.latent_vals, model, data_visualize, axis, latent_index=latent_index)
|
||||
|
||||
|
||||
|
||||
class lvm_dimselect(lvm):
|
||||
"""
|
||||
A visualizer for latent variable models which allows selection of the latent dimensions to use by clicking on a bar chart of their length scales.
|
||||
|
||||
For an example of the visualizer's use try:
|
||||
|
||||
GPy.examples.dimensionality_reduction.BGPVLM_oil()
|
||||
|
||||
"""
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0, 1]):
|
||||
if latent_axes==None and sense_axes==None:
|
||||
|
|
@ -133,24 +168,9 @@ class lvm_dimselect(lvm):
|
|||
else:
|
||||
self.sense_axes = sense_axes
|
||||
|
||||
lvm.__init__(self,vals,model,data_visualize,latent_axes,latent_index)
|
||||
self.show_sensitivities()
|
||||
lvm.__init__(self,vals,model,data_visualize,latent_axes,sense_axes,latent_index)
|
||||
print "use left and right mouse butons to select dimensions"
|
||||
|
||||
def show_sensitivities(self):
|
||||
# A click in the bar chart axis for selection a dimension.
|
||||
self.sense_axes.cla()
|
||||
self.sense_axes.bar(np.arange(self.model.Q),1./self.model.input_sensitivity(),color='b')
|
||||
|
||||
if self.latent_index[1] == self.latent_index[0]:
|
||||
self.sense_axes.bar(np.array(self.latent_index[0]),1./self.model.input_sensitivity()[self.latent_index[0]],color='y')
|
||||
self.sense_axes.bar(np.array(self.latent_index[1]),1./self.model.input_sensitivity()[self.latent_index[1]],color='y')
|
||||
|
||||
else:
|
||||
self.sense_axes.bar(np.array(self.latent_index[0]),1./self.model.input_sensitivity()[self.latent_index[0]],color='g')
|
||||
self.sense_axes.bar(np.array(self.latent_index[1]),1./self.model.input_sensitivity()[self.latent_index[1]],color='r')
|
||||
|
||||
self.sense_axes.figure.canvas.draw()
|
||||
|
||||
def on_click(self, event):
|
||||
|
||||
|
|
@ -177,12 +197,6 @@ class lvm_dimselect(lvm):
|
|||
self.called = True
|
||||
|
||||
|
||||
def on_move(self, event):
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
if self.called and self.move_on:
|
||||
self.latent_values[self.latent_index[0]]=event.xdata
|
||||
self.latent_values[self.latent_index[1]]=event.ydata
|
||||
self.modify(self.latent_values)
|
||||
|
||||
def on_leave(self,event):
|
||||
latent_values = self.latent_values.copy()
|
||||
|
|
@ -305,7 +319,8 @@ class stick_show(mocap_data_show):
|
|||
|
||||
def process_values(self, vals):
|
||||
self.vals = vals.reshape((3, vals.shape[1]/3)).T
|
||||
|
||||
print vals
|
||||
|
||||
class skeleton_show(mocap_data_show):
|
||||
"""data_show class for visualizing motion capture data encoded as a skeleton with angles."""
|
||||
def __init__(self, vals, skel, padding=0, axes=None):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue