mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
num_data refactoring
This commit is contained in:
parent
efbf169a6a
commit
a6ed003194
3 changed files with 21 additions and 22 deletions
|
|
@ -28,8 +28,7 @@ class GPBase(Model):
|
||||||
self._Xmean = np.zeros((1, self.input_dim))
|
self._Xmean = np.zeros((1, self.input_dim))
|
||||||
self._Xstd = np.ones((1, self.input_dim))
|
self._Xstd = np.ones((1, self.input_dim))
|
||||||
|
|
||||||
Model.__init__(self)
|
super(GPBase, self).__init__()
|
||||||
|
|
||||||
# All leaf nodes should call self._set_params(self._get_params()) at
|
# All leaf nodes should call self._set_params(self._get_params()) at
|
||||||
# the end
|
# the end
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import pylab as pb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
|
def plot_latent(model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
|
||||||
"""
|
"""
|
||||||
:param labels: a np.array of size Model.N containing labels for the points (can be number, strings, etc)
|
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
||||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||||
"""
|
"""
|
||||||
if ax is None:
|
if ax is None:
|
||||||
|
|
@ -12,26 +12,26 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
|
||||||
util.plot.Tango.reset()
|
util.plot.Tango.reset()
|
||||||
|
|
||||||
if labels is None:
|
if labels is None:
|
||||||
labels = np.ones(Model.N)
|
labels = np.ones(model.num_data)
|
||||||
if which_indices is None:
|
if which_indices is None:
|
||||||
if Model.input_dim==1:
|
if model.input_dim==1:
|
||||||
input_1 = 0
|
input_1 = 0
|
||||||
input_2 = None
|
input_2 = None
|
||||||
if Model.input_dim==2:
|
if model.input_dim==2:
|
||||||
input_1, input_2 = 0,1
|
input_1, input_2 = 0,1
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
input_1, input_2 = np.argsort(Model.input_sensitivity())[:2]
|
input_1, input_2 = np.argsort(model.input_sensitivity())[:2]
|
||||||
except:
|
except:
|
||||||
raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'"
|
raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'"
|
||||||
else:
|
else:
|
||||||
input_1, input_2 = which_indices
|
input_1, input_2 = which_indices
|
||||||
|
|
||||||
#first, plot the output variance as a function of the latent space
|
#first, plot the output variance as a function of the latent space
|
||||||
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(Model.X[:,[input_1, input_2]],resolution=resolution)
|
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(model.X[:,[input_1, input_2]],resolution=resolution)
|
||||||
Xtest_full = np.zeros((Xtest.shape[0], Model.X.shape[1]))
|
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
|
||||||
Xtest_full[:, :2] = Xtest
|
Xtest_full[:, :2] = Xtest
|
||||||
mu, var, low, up = Model.predict(Xtest_full)
|
mu, var, low, up = model.predict(Xtest_full)
|
||||||
var = var[:, :1]
|
var = var[:, :1]
|
||||||
ax.imshow(var.reshape(resolution, resolution).T,
|
ax.imshow(var.reshape(resolution, resolution).T,
|
||||||
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
|
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
|
||||||
|
|
@ -55,12 +55,12 @@ def plot_latent(Model, labels=None, which_indices=None, resolution=50, ax=None,
|
||||||
m = marker
|
m = marker
|
||||||
|
|
||||||
index = np.nonzero(labels==ul)[0]
|
index = np.nonzero(labels==ul)[0]
|
||||||
if Model.input_dim==1:
|
if model.input_dim==1:
|
||||||
x = Model.X[index,input_1]
|
x = model.X[index,input_1]
|
||||||
y = np.zeros(index.size)
|
y = np.zeros(index.size)
|
||||||
else:
|
else:
|
||||||
x = Model.X[index,input_1]
|
x = model.X[index,input_1]
|
||||||
y = Model.X[index,input_2]
|
y = model.X[index,input_2]
|
||||||
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
|
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
|
||||||
|
|
||||||
ax.set_xlabel('latent dimension %i'%input_1)
|
ax.set_xlabel('latent dimension %i'%input_1)
|
||||||
|
|
|
||||||
|
|
@ -43,16 +43,16 @@ class vector_show(data_show):
|
||||||
|
|
||||||
|
|
||||||
class lvm(data_show):
|
class lvm(data_show):
|
||||||
def __init__(self, vals, Model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
|
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
|
||||||
"""Visualize a latent variable Model
|
"""Visualize a latent variable model
|
||||||
|
|
||||||
:param Model: the latent variable Model to visualize.
|
:param model: the latent variable model to visualize.
|
||||||
:param data_visualize: the object used to visualize the data which has been modelled.
|
:param data_visualize: the object used to visualize the data which has been modelled.
|
||||||
:type data_visualize: visualize.data_show type.
|
:type data_visualize: visualize.data_show type.
|
||||||
:param latent_axes: the axes where the latent visualization should be plotted.
|
:param latent_axes: the axes where the latent visualization should be plotted.
|
||||||
"""
|
"""
|
||||||
if vals == None:
|
if vals == None:
|
||||||
vals = Model.X[0]
|
vals = model.X[0]
|
||||||
|
|
||||||
data_show.__init__(self, vals, axes=latent_axes)
|
data_show.__init__(self, vals, axes=latent_axes)
|
||||||
|
|
||||||
|
|
@ -68,13 +68,13 @@ class lvm(data_show):
|
||||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||||
|
|
||||||
self.data_visualize = data_visualize
|
self.data_visualize = data_visualize
|
||||||
self.Model = Model
|
self.Model = model
|
||||||
self.latent_axes = latent_axes
|
self.latent_axes = latent_axes
|
||||||
self.sense_axes = sense_axes
|
self.sense_axes = sense_axes
|
||||||
self.called = False
|
self.called = False
|
||||||
self.move_on = False
|
self.move_on = False
|
||||||
self.latent_index = latent_index
|
self.latent_index = latent_index
|
||||||
self.latent_dim = Model.input_dim
|
self.latent_dim = model.input_dim
|
||||||
|
|
||||||
# The red cross which shows current latent point.
|
# The red cross which shows current latent point.
|
||||||
self.latent_values = vals
|
self.latent_values = vals
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue