mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-03 00:32:39 +02:00
num_data refactoring
This commit is contained in:
parent
efbf169a6a
commit
a6ed003194
3 changed files with 21 additions and 22 deletions
|
|
@ -28,8 +28,7 @@ class GPBase(Model):
|
|||
self._Xmean = np.zeros((1, self.input_dim))
|
||||
self._Xstd = np.ones((1, self.input_dim))
|
||||
|
||||
Model.__init__(self)
|
||||
|
||||
super(GPBase, self).__init__()
|
||||
# All leaf nodes should call self._set_params(self._get_params()) at
|
||||
# the end
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import pylab as pb
|
|||
import numpy as np
|
||||
from .. import util
|
||||
|
||||
def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
|
||||
def plot_latent(model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
|
||||
"""
|
||||
:param labels: a np.array of size Model.N containing labels for the points (can be number, strings, etc)
|
||||
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||
"""
|
||||
if ax is None:
|
||||
|
|
@ -12,26 +12,26 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
|
|||
util.plot.Tango.reset()
|
||||
|
||||
if labels is None:
|
||||
labels = np.ones(Model.N)
|
||||
labels = np.ones(model.num_data)
|
||||
if which_indices is None:
|
||||
if Model.input_dim==1:
|
||||
if model.input_dim==1:
|
||||
input_1 = 0
|
||||
input_2 = None
|
||||
if Model.input_dim==2:
|
||||
if model.input_dim==2:
|
||||
input_1, input_2 = 0,1
|
||||
else:
|
||||
try:
|
||||
input_1, input_2 = np.argsort(Model.input_sensitivity())[:2]
|
||||
input_1, input_2 = np.argsort(model.input_sensitivity())[:2]
|
||||
except:
|
||||
raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'"
|
||||
else:
|
||||
input_1, input_2 = which_indices
|
||||
|
||||
#first, plot the output variance as a function of the latent space
|
||||
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(Model.X[:,[input_1, input_2]],resolution=resolution)
|
||||
Xtest_full = np.zeros((Xtest.shape[0], Model.X.shape[1]))
|
||||
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(model.X[:,[input_1, input_2]],resolution=resolution)
|
||||
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
|
||||
Xtest_full[:, :2] = Xtest
|
||||
mu, var, low, up = Model.predict(Xtest_full)
|
||||
mu, var, low, up = model.predict(Xtest_full)
|
||||
var = var[:, :1]
|
||||
ax.imshow(var.reshape(resolution, resolution).T,
|
||||
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
|
||||
|
|
@ -55,12 +55,12 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
|
|||
m = marker
|
||||
|
||||
index = np.nonzero(labels==ul)[0]
|
||||
if Model.input_dim==1:
|
||||
x = Model.X[index,input_1]
|
||||
if model.input_dim==1:
|
||||
x = model.X[index,input_1]
|
||||
y = np.zeros(index.size)
|
||||
else:
|
||||
x = Model.X[index,input_1]
|
||||
y = Model.X[index,input_2]
|
||||
x = model.X[index,input_1]
|
||||
y = model.X[index,input_2]
|
||||
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
|
||||
|
||||
ax.set_xlabel('latent dimension %i'%input_1)
|
||||
|
|
@ -88,4 +88,4 @@ def plot_latent_indices(Model, which_indices=None, *args, **kwargs):
|
|||
ax = plot_latent(Model, which_indices=[input_1, input_2], *args, **kwargs)
|
||||
# TODO: Here test if there are inducing points...
|
||||
ax.plot(Model.Z[:, input_1], Model.Z[:, input_2], '^w')
|
||||
return ax
|
||||
return ax
|
||||
|
|
|
|||
|
|
@ -43,16 +43,16 @@ class vector_show(data_show):
|
|||
|
||||
|
||||
class lvm(data_show):
|
||||
def __init__(self, vals, Model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
|
||||
"""Visualize a latent variable Model
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
|
||||
"""Visualize a latent variable model
|
||||
|
||||
:param Model: the latent variable Model to visualize.
|
||||
:param model: the latent variable model to visualize.
|
||||
:param data_visualize: the object used to visualize the data which has been modelled.
|
||||
:type data_visualize: visualize.data_show type.
|
||||
:param latent_axes: the axes where the latent visualization should be plotted.
|
||||
"""
|
||||
if vals == None:
|
||||
vals = Model.X[0]
|
||||
vals = model.X[0]
|
||||
|
||||
data_show.__init__(self, vals, axes=latent_axes)
|
||||
|
||||
|
|
@ -68,13 +68,13 @@ class lvm(data_show):
|
|||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||
|
||||
self.data_visualize = data_visualize
|
||||
self.Model = Model
|
||||
self.Model = model
|
||||
self.latent_axes = latent_axes
|
||||
self.sense_axes = sense_axes
|
||||
self.called = False
|
||||
self.move_on = False
|
||||
self.latent_index = latent_index
|
||||
self.latent_dim = Model.input_dim
|
||||
self.latent_dim = model.input_dim
|
||||
|
||||
# The red cross which shows current latent point.
|
||||
self.latent_values = vals
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue