num_data refactoring

This commit is contained in:
Max Zwiessele 2013-06-05 17:33:22 +01:00
parent efbf169a6a
commit a6ed003194
3 changed files with 21 additions and 22 deletions

View file

@ -28,8 +28,7 @@ class GPBase(Model):
self._Xmean = np.zeros((1, self.input_dim)) self._Xmean = np.zeros((1, self.input_dim))
self._Xstd = np.ones((1, self.input_dim)) self._Xstd = np.ones((1, self.input_dim))
Model.__init__(self) super(GPBase, self).__init__()
# All leaf nodes should call self._set_params(self._get_params()) at # All leaf nodes should call self._set_params(self._get_params()) at
# the end # the end

View file

@ -2,9 +2,9 @@ import pylab as pb
import numpy as np import numpy as np
from .. import util from .. import util
def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40): def plot_latent(model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
""" """
:param labels: a np.array of size Model.N containing labels for the points (can be number, strings, etc) :param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance :param resolution: the resolution of the grid on which to evaluate the predictive variance
""" """
if ax is None: if ax is None:
@ -12,26 +12,26 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
util.plot.Tango.reset() util.plot.Tango.reset()
if labels is None: if labels is None:
labels = np.ones(Model.N) labels = np.ones(model.num_data)
if which_indices is None: if which_indices is None:
if Model.input_dim==1: if model.input_dim==1:
input_1 = 0 input_1 = 0
input_2 = None input_2 = None
if Model.input_dim==2: if model.input_dim==2:
input_1, input_2 = 0,1 input_1, input_2 = 0,1
else: else:
try: try:
input_1, input_2 = np.argsort(Model.input_sensitivity())[:2] input_1, input_2 = np.argsort(model.input_sensitivity())[:2]
except: except:
raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'" raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'"
else: else:
input_1, input_2 = which_indices input_1, input_2 = which_indices
#first, plot the output variance as a function of the latent space #first, plot the output variance as a function of the latent space
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(Model.X[:,[input_1, input_2]],resolution=resolution) Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(model.X[:,[input_1, input_2]],resolution=resolution)
Xtest_full = np.zeros((Xtest.shape[0], Model.X.shape[1])) Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
Xtest_full[:, :2] = Xtest Xtest_full[:, :2] = Xtest
mu, var, low, up = Model.predict(Xtest_full) mu, var, low, up = model.predict(Xtest_full)
var = var[:, :1] var = var[:, :1]
ax.imshow(var.reshape(resolution, resolution).T, ax.imshow(var.reshape(resolution, resolution).T,
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower') extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
@ -55,12 +55,12 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
m = marker m = marker
index = np.nonzero(labels==ul)[0] index = np.nonzero(labels==ul)[0]
if Model.input_dim==1: if model.input_dim==1:
x = Model.X[index,input_1] x = model.X[index,input_1]
y = np.zeros(index.size) y = np.zeros(index.size)
else: else:
x = Model.X[index,input_1] x = model.X[index,input_1]
y = Model.X[index,input_2] y = model.X[index,input_2]
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label) ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
ax.set_xlabel('latent dimension %i'%input_1) ax.set_xlabel('latent dimension %i'%input_1)
@ -88,4 +88,4 @@ def plot_latent_indices(Model, which_indices=None, *args, **kwargs):
ax = plot_latent(Model, which_indices=[input_1, input_2], *args, **kwargs) ax = plot_latent(Model, which_indices=[input_1, input_2], *args, **kwargs)
# TODO: Here test if there are inducing points... # TODO: Here test if there are inducing points...
ax.plot(Model.Z[:, input_1], Model.Z[:, input_2], '^w') ax.plot(Model.Z[:, input_1], Model.Z[:, input_2], '^w')
return ax return ax

View file

@ -43,16 +43,16 @@ class vector_show(data_show):
class lvm(data_show): class lvm(data_show):
def __init__(self, vals, Model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]): def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable Model """Visualize a latent variable model
:param Model: the latent variable Model to visualize. :param model: the latent variable model to visualize.
:param data_visualize: the object used to visualize the data which has been modelled. :param data_visualize: the object used to visualize the data which has been modelled.
:type data_visualize: visualize.data_show type. :type data_visualize: visualize.data_show type.
:param latent_axes: the axes where the latent visualization should be plotted. :param latent_axes: the axes where the latent visualization should be plotted.
""" """
if vals == None: if vals == None:
vals = Model.X[0] vals = model.X[0]
data_show.__init__(self, vals, axes=latent_axes) data_show.__init__(self, vals, axes=latent_axes)
@ -68,13 +68,13 @@ class lvm(data_show):
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter) self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
self.data_visualize = data_visualize self.data_visualize = data_visualize
self.Model = Model self.Model = model
self.latent_axes = latent_axes self.latent_axes = latent_axes
self.sense_axes = sense_axes self.sense_axes = sense_axes
self.called = False self.called = False
self.move_on = False self.move_on = False
self.latent_index = latent_index self.latent_index = latent_index
self.latent_dim = Model.input_dim self.latent_dim = model.input_dim
# The red cross which shows current latent point. # The red cross which shows current latent point.
self.latent_values = vals self.latent_values = vals