num_data refactoring

This commit is contained in:
Max Zwiessele 2013-06-05 17:33:22 +01:00
parent efbf169a6a
commit a6ed003194
3 changed files with 21 additions and 22 deletions

View file

@ -28,8 +28,7 @@ class GPBase(Model):
self._Xmean = np.zeros((1, self.input_dim))
self._Xstd = np.ones((1, self.input_dim))
Model.__init__(self)
super(GPBase, self).__init__()
# All leaf nodes should call self._set_params(self._get_params()) at
# the end

View file

@ -2,9 +2,9 @@ import pylab as pb
import numpy as np
from .. import util
def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
def plot_latent(model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
"""
:param labels: a np.array of size Model.N containing labels for the points (can be number, strings, etc)
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance
"""
if ax is None:
@ -12,26 +12,26 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
util.plot.Tango.reset()
if labels is None:
labels = np.ones(Model.N)
labels = np.ones(model.num_data)
if which_indices is None:
if Model.input_dim==1:
if model.input_dim==1:
input_1 = 0
input_2 = None
if Model.input_dim==2:
if model.input_dim==2:
input_1, input_2 = 0,1
else:
try:
input_1, input_2 = np.argsort(Model.input_sensitivity())[:2]
input_1, input_2 = np.argsort(model.input_sensitivity())[:2]
except:
raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'"
else:
input_1, input_2 = which_indices
#first, plot the output variance as a function of the latent space
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(Model.X[:,[input_1, input_2]],resolution=resolution)
Xtest_full = np.zeros((Xtest.shape[0], Model.X.shape[1]))
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(model.X[:,[input_1, input_2]],resolution=resolution)
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
Xtest_full[:, :2] = Xtest
mu, var, low, up = Model.predict(Xtest_full)
mu, var, low, up = model.predict(Xtest_full)
var = var[:, :1]
ax.imshow(var.reshape(resolution, resolution).T,
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
@ -55,12 +55,12 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
m = marker
index = np.nonzero(labels==ul)[0]
if Model.input_dim==1:
x = Model.X[index,input_1]
if model.input_dim==1:
x = model.X[index,input_1]
y = np.zeros(index.size)
else:
x = Model.X[index,input_1]
y = Model.X[index,input_2]
x = model.X[index,input_1]
y = model.X[index,input_2]
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
ax.set_xlabel('latent dimension %i'%input_1)
@ -88,4 +88,4 @@ def plot_latent_indices(Model, which_indices=None, *args, **kwargs):
ax = plot_latent(Model, which_indices=[input_1, input_2], *args, **kwargs)
# TODO: Here test if there are inducing points...
ax.plot(Model.Z[:, input_1], Model.Z[:, input_2], '^w')
return ax
return ax

View file

@ -43,16 +43,16 @@ class vector_show(data_show):
class lvm(data_show):
def __init__(self, vals, Model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable Model
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable model
:param Model: the latent variable Model to visualize.
:param model: the latent variable model to visualize.
:param data_visualize: the object used to visualize the data which has been modelled.
:type data_visualize: visualize.data_show type.
:param latent_axes: the axes where the latent visualization should be plotted.
"""
if vals == None:
vals = Model.X[0]
vals = model.X[0]
data_show.__init__(self, vals, axes=latent_axes)
@ -68,13 +68,13 @@ class lvm(data_show):
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
self.data_visualize = data_visualize
self.Model = Model
self.Model = model
self.latent_axes = latent_axes
self.sense_axes = sense_axes
self.called = False
self.move_on = False
self.latent_index = latent_index
self.latent_dim = Model.input_dim
self.latent_dim = model.input_dim
# The red cross which shows current latent point.
self.latent_values = vals