pcikling?

This commit is contained in:
Max Zwiessele 2013-06-25 17:43:42 +01:00
parent d3a4f99b89
commit e869fcaf65
9 changed files with 149 additions and 370 deletions

View file

@ -31,6 +31,10 @@ class GP(GPBase):
GPBase.__init__(self, X, likelihood, kernel, normalize_X=normalize_X)
self._set_params(self._get_params())
def __setstate__(self, state):
GPBase.__setstate__(self, state)
self._set_params(self._get_params())
def _set_params(self, p):
self.kern._set_params_transformed(p[:self.kern.num_params_transformed()])
self.likelihood._set_params(p[self.kern.num_params_transformed():])

View file

@ -32,6 +32,32 @@ class GPBase(Model):
# All leaf nodes should call self._set_params(self._get_params()) at
# the end
def __getstate__(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return Model.__getstate__(self) + [self.X,
self.num_data,
self.input_dim,
self.kern,
self.likelihood,
self.output_dim,
self._Xoffset,
self._Xscale]
def __setstate__(self, state):
self._Xscale = state.pop()
self._Xoffset = state.pop()
self.output_dim = state.pop()
self.likelihood = state.pop()
self.kern = state.pop()
self.input_dim = state.pop()
self.num_data = state.pop()
self.X = state.pop()
Model.__setstate__(self, state)
self._set_params(self._get_params())
def plot_f(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, full_cov=False, fignum=None, ax=None):
"""
Plot the GP's view of the world, where the data is normalized and the

View file

@ -32,6 +32,25 @@ class Model(Parameterised):
def _log_likelihood_gradients(self):
raise NotImplementedError, "this needs to be implemented to use the Model class"
def __getstate__(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return Parameterised.__getstate__(self) + \
[self.priors, self.optimization_runs,
self.sampling_runs, self.preferred_optimizer]
def __setstate__(self, state):
"""
set state from previous call to getstate
"""
self.preferred_optimizer = state.pop()
self.sampling_runs = state.pop()
self.optimization_runs = state.pop()
self.priors = state.pop()
Parameterised.__setstate__(self, state)
def set_prior(self, regexp, what):
"""
Sets priors on the Model parameters.

View file

@ -29,6 +29,24 @@ class Parameterised(object):
"""Returns a (deep) copy of the current model """
return copy.deepcopy(self)
def __getstate__(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return [self.tied_indices,
self.fixed_indices,
self.fixed_values,
self.constrained_indices,
self.constraints]
def __setstate__(self, state):
self.constraints = state.pop()
self.constrained_indices = state.pop()
self.fixed_values = state.pop()
self.fixed_indices = state.pop()
self.tied_indices = state.pop()
@property
def params(self):
"""

View file

@ -33,10 +33,11 @@ class SparseGP(GPBase):
self.Z = Z
self.num_inducing = Z.shape[0]
self.likelihood = likelihood
# self.likelihood = likelihood
if X_variance is None:
self.has_uncertain_inputs = False
self.X_variance = None
else:
assert X_variance.shape == X.shape
self.has_uncertain_inputs = True
@ -49,6 +50,23 @@ class SparseGP(GPBase):
if self.has_uncertain_inputs:
self.X_variance /= np.square(self._Xscale)
def __getstate__(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return GPBase.__getstate__(self) + [self.Z,
self.num_inducing,
self.has_uncertain_inputs,
self.X_variance]
def __setstate__(self, state):
self.X_variance = state.pop()
self.has_uncertain_inputs = state.pop()
self.num_inducing = state.pop()
self.Z = state.pop()
GPBase.__setstate__(self, state)
def _compute_kernel_matrices(self):
# kernel computations, using BGPLVM notation
self.Kmm = self.kern.K(self.Z)

View file

@ -45,13 +45,14 @@ class kern(Parameterised):
Parameterised.__init__(self)
def plot_ARD(self, fignum=None, ax=None):
def plot_ARD(self, fignum=None, ax=None, title=None):
"""If an ARD kernel is present, it bar-plots the ARD parameters"""
if ax is None:
fig = pb.figure(fignum)
ax = fig.add_subplot(111)
for p in self.parts:
if hasattr(p, 'ARD') and p.ARD:
if title is None:
ax.set_title('ARD parameters, %s kernel' % p.name)
if p.name == 'linear':

View file

@ -45,32 +45,19 @@ class BayesianGPLVM(SparseGP, GPLVM):
if kernel is None:
kernel = kern.rbf(input_dim) + kern.white(input_dim)
self.oldpsave = oldpsave
self._oldps = []
self._debug = _debug
if self._debug:
self.f_call = 0
self._count = itertools.count()
self._savedklll = []
self._savedparams = []
self._savedgradients = []
self._savederrors = []
self._savedpsiKmm = []
self._savedABCD = []
SparseGP.__init__(self, X, likelihood, kernel, Z=Z, X_variance=X_variance, **kwargs)
self._set_params(self._get_params())
@property
def oldps(self):
return self._oldps
@oldps.setter
def oldps(self, p):
if len(self._oldps) == (self.oldpsave + 1):
self._oldps.pop()
# if len(self._oldps) == 0 or not np.any([np.any(np.abs(p - op) > 1e-5) for op in self._oldps]):
self._oldps.insert(0, p.copy())
def __getstate__(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return [self.init] + SparseGP.__getstate__(self)
def __setstate__(self, state):
self.init = state.pop()
SparseGP.__setstate__(self, state)
def _get_param_names(self):
X_names = sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], [])
@ -90,24 +77,11 @@ class BayesianGPLVM(SparseGP, GPLVM):
x = np.hstack((self.X.flatten(), self.X_variance.flatten(), SparseGP._get_params(self)))
return x
def _clipped(self, x):
return x # np.clip(x, -1e300, 1e300)
def _set_params(self, x, save_old=True, save_count=0):
# try:
x = self._clipped(x)
N, input_dim = self.num_data, self.input_dim
self.X = x[:self.X.size].reshape(N, input_dim).copy()
self.X_variance = x[(N * input_dim):(2 * N * input_dim)].reshape(N, input_dim).copy()
SparseGP._set_params(self, x[(2 * N * input_dim):])
# self.oldps = x
# except (LinAlgError, FloatingPointError, ZeroDivisionError):
# print "\rWARNING: Caught LinAlgError, continueing without setting "
# if self._debug:
# self._savederrors.append(self.f_call)
# if save_count > 10:
# raise
# self._set_params(self.oldps[-1], save_old=False, save_count=save_count + 1)
def dKL_dmuS(self):
dKL_dS = (1. - (1. / (self.X_variance))) * 0.5
@ -131,53 +105,16 @@ class BayesianGPLVM(SparseGP, GPLVM):
def log_likelihood(self):
ll = SparseGP.log_likelihood(self)
kl = self.KL_divergence()
# if ll < -2E4:
# ll = -2E4 + np.random.randn()
# if kl > 5E4:
# kl = 5E4 + np.random.randn()
if self._debug:
self.f_call = self._count.next()
if self.f_call % 1 == 0:
self._savedklll.append([self.f_call, ll, kl])
self._savedparams.append([self.f_call, self._get_params()])
self._savedgradients.append([self.f_call, self._log_likelihood_gradients()])
self._savedpsiKmm.append([self.f_call, [self.Kmm, self.dL_dKmm]])
# sf2 = self.scale_factor ** 2
if self.likelihood.is_heteroscedastic:
A = -0.5 * self.num_data * self.input_dim * np.log(2.*np.pi) + 0.5 * np.sum(np.log(self.likelihood.precision)) - 0.5 * np.sum(self.V * self.likelihood.Y)
# B = -0.5 * self.input_dim * (np.sum(self.likelihood.precision.flatten() * self.psi0) - np.trace(self.A) * sf2)
B = -0.5 * self.input_dim * (np.sum(self.likelihood.precision.flatten() * self.psi0) - np.trace(self.A))
else:
A = -0.5 * self.num_data * self.input_dim * (np.log(2.*np.pi) + np.log(self.likelihood._variance)) - 0.5 * self.likelihood.precision * self.likelihood.trYYT
# B = -0.5 * self.input_dim * (np.sum(self.likelihood.precision * self.psi0) - np.trace(self.A) * sf2)
B = -0.5 * self.input_dim * (np.sum(self.likelihood.precision * self.psi0) - np.trace(self.A))
C = -self.input_dim * (np.sum(np.log(np.diag(self.LB)))) # + 0.5 * self.num_inducing * np.log(sf2))
D = 0.5 * np.sum(np.square(self._LBi_Lmi_psi1V))
self._savedABCD.append([self.f_call, A, B, C, D])
# print "\nkl:", kl, "ll:", ll
return ll - kl
def _log_likelihood_gradients(self):
dKL_dmu, dKL_dS = self.dKL_dmuS()
dL_dmu, dL_dS = self.dL_dmuS()
# TODO: find way to make faster
d_dmu = (dL_dmu - dKL_dmu).flatten()
d_dS = (dL_dS - dKL_dS).flatten()
# TEST KL: ====================
# d_dmu = (dKL_dmu).flatten()
# d_dS = (dKL_dS).flatten()
# ========================
# TEST L: ====================
# d_dmu = (dL_dmu).flatten()
# d_dS = (dL_dS).flatten()
# ========================
self.dbound_dmuS = np.hstack((d_dmu, d_dS))
self.dbound_dZtheta = SparseGP._log_likelihood_gradients(self)
return self._clipped(np.hstack((self.dbound_dmuS.flatten(), self.dbound_dZtheta)))
return np.hstack((self.dbound_dmuS.flatten(), self.dbound_dZtheta))
def plot_latent(self, *args, **kwargs):
return plot_latent.plot_latent_indices(self, *args, **kwargs)
@ -256,275 +193,6 @@ class BayesianGPLVM(SparseGP, GPLVM):
fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
return fig
def __getstate__(self):
return (self.likelihood, self.input_dim, self.X, self.X_variance,
self.init, self.num_inducing, self.Z, self.kern,
self.oldpsave, self._debug)
def __setstate__(self, state):
self.__init__(*state)
def _debug_filter_params(self, x):
start, end = 0, self.X.size,
X = x[start:end].reshape(self.num_data, self.input_dim)
start, end = end, end + self.X_variance.size
X_v = x[start:end].reshape(self.num_data, self.input_dim)
start, end = end, end + (self.num_inducing * self.input_dim)
Z = x[start:end].reshape(self.num_inducing, self.input_dim)
start, end = end, end + self.input_dim
theta = x[start:]
return X, X_v, Z, theta
def _debug_get_axis(self, figs):
if figs[-1].axes:
ax1 = figs[-1].axes[0]
ax1.cla()
else:
ax1 = figs[-1].add_subplot(111)
return ax1
def _debug_plot(self):
assert self._debug, "must enable _debug, to debug-plot"
import pylab
# from mpl_toolkits.mplot3d import Axes3D
figs = [pylab.figure('BGPLVM DEBUG', figsize=(12, 4))]
# fig.clf()
# log like
# splotshape = (6, 4)
# ax1 = pylab.subplot2grid(splotshape, (0, 0), 1, 4)
ax1 = self._debug_get_axis(figs)
ax1.text(.5, .5, "Optimization", alpha=.3, transform=ax1.transAxes,
ha='center', va='center')
kllls = np.array(self._savedklll)
LL, = ax1.plot(kllls[:, 0], kllls[:, 1] - kllls[:, 2], '-', label=r'$\log p(\mathbf{Y})$', mew=1.5)
KL, = ax1.plot(kllls[:, 0], kllls[:, 2], '-', label=r'$\mathcal{KL}(p||q)$', mew=1.5)
L, = ax1.plot(kllls[:, 0], kllls[:, 1], '-', label=r'$L$', mew=1.5) # \mathds{E}_{q(\mathbf{X})}[p(\mathbf{Y|X})\frac{p(\mathbf{X})}{q(\mathbf{X})}]
param_dict = dict(self._savedparams)
gradient_dict = dict(self._savedgradients)
# kmm_dict = dict(self._savedpsiKmm)
iters = np.array(param_dict.keys())
ABCD_dict = np.array(self._savedABCD)
self.showing = 0
# ax2 = pylab.subplot2grid(splotshape, (1, 0), 2, 4)
figs.append(pylab.figure("BGPLVM DEBUG X", figsize=(12, 4)))
ax2 = self._debug_get_axis(figs)
ax2.text(.5, .5, r"$\mathbf{X}$", alpha=.5, transform=ax2.transAxes,
ha='center', va='center')
figs[-1].canvas.draw()
figs[-1].tight_layout(rect=(0, 0, 1, .86))
# ax3 = pylab.subplot2grid(splotshape, (3, 0), 2, 4, sharex=ax2)
figs.append(pylab.figure("BGPLVM DEBUG S", figsize=(12, 4)))
ax3 = self._debug_get_axis(figs)
ax3.text(.5, .5, r"$\mathbf{S}$", alpha=.5, transform=ax3.transAxes,
ha='center', va='center')
figs[-1].canvas.draw()
figs[-1].tight_layout(rect=(0, 0, 1, .86))
# ax4 = pylab.subplot2grid(splotshape, (5, 0), 2, 2)
figs.append(pylab.figure("BGPLVM DEBUG Z", figsize=(6, 4)))
ax4 = self._debug_get_axis(figs)
ax4.text(.5, .5, r"$\mathbf{Z}$", alpha=.5, transform=ax4.transAxes,
ha='center', va='center')
figs[-1].canvas.draw()
figs[-1].tight_layout(rect=(0, 0, 1, .86))
# ax5 = pylab.subplot2grid(splotshape, (5, 2), 2, 2)
figs.append(pylab.figure("BGPLVM DEBUG theta", figsize=(6, 4)))
ax5 = self._debug_get_axis(figs)
ax5.text(.5, .5, r"${\theta}$", alpha=.5, transform=ax5.transAxes,
ha='center', va='center')
figs[-1].canvas.draw()
figs[-1].tight_layout(rect=(.15, 0, 1, .86))
# figs.append(pylab.figure("BGPLVM DEBUG Kmm", figsize=(12, 6)))
# fig = figs[-1]
# ax6 = fig.add_subplot(121)
# ax6.text(.5, .5, r"${\mathbf{K}_{mm}}$", color='magenta', alpha=.5, transform=ax6.transAxes,
# ha='center', va='center')
# ax7 = fig.add_subplot(122)
# ax7.text(.5, .5, r"${\frac{dL}{dK_{mm}}}$", color='magenta', alpha=.5, transform=ax7.transAxes,
# ha='center', va='center')
figs.append(pylab.figure("BGPLVM DEBUG Kmm", figsize=(12, 6)))
fig = figs[-1]
ax8 = fig.add_subplot(121)
ax8.text(.5, .5, r"${\mathbf{A,B,C,input_dim}}$", color='k', alpha=.5, transform=ax8.transAxes,
ha='center', va='center')
ax8.plot(ABCD_dict[:, 0], ABCD_dict[:, 1], label='A')
ax8.plot(ABCD_dict[:, 0], ABCD_dict[:, 2], label='B')
ax8.plot(ABCD_dict[:, 0], ABCD_dict[:, 3], label='C')
ax8.plot(ABCD_dict[:, 0], ABCD_dict[:, 4], label='input_dim')
ax8.legend()
figs[-1].canvas.draw()
figs[-1].tight_layout(rect=(.15, 0, 1, .86))
X, S, Z, theta = self._debug_filter_params(param_dict[self.showing])
Xg, Sg, Zg, thetag = self._debug_filter_params(gradient_dict[self.showing])
# Xg, Sg, Zg, thetag = -Xg, -Sg, -Zg, -thetag
quiver_units = 'xy'
quiver_scale = 1
quiver_scale_units = 'xy'
Xlatentplts = ax2.plot(X, ls="-", marker="x")
colors = colorConverter.to_rgba_array([p.get_color() for p in Xlatentplts], .4)
Ulatent = np.zeros_like(X)
xlatent = np.tile(np.arange(0, X.shape[0])[:, None], X.shape[1])
Xlatentgrads = ax2.quiver(xlatent, X, Ulatent, Xg, color=colors,
units=quiver_units, scale_units=quiver_scale_units,
scale=quiver_scale)
Slatentplts = ax3.plot(S, ls="-", marker="x")
Slatentgrads = ax3.quiver(xlatent, S, Ulatent, Sg, color=colors,
units=quiver_units, scale_units=quiver_scale_units,
scale=quiver_scale)
ax3.set_ylim(0, 1.)
xZ = np.tile(np.arange(0, Z.shape[0])[:, None], Z.shape[1])
UZ = np.zeros_like(Z)
Zplts = ax4.plot(Z, ls="-", marker="x")
Zgrads = ax4.quiver(xZ, Z, UZ, Zg, color=colors,
units=quiver_units, scale_units=quiver_scale_units,
scale=quiver_scale)
xtheta = np.arange(len(theta))
Utheta = np.zeros_like(theta)
thetaplts = ax5.bar(xtheta - .4, theta, color=colors)
thetagrads = ax5.quiver(xtheta, theta, Utheta, thetag, color=colors,
units=quiver_units, scale_units=quiver_scale_units,
scale=quiver_scale,
edgecolors=('k',), linewidths=[1])
pylab.setp(thetaplts, zorder=0)
pylab.setp(thetagrads, zorder=10)
ax5.set_xticks(np.arange(len(theta)))
ax5.set_xticklabels(self._get_param_names()[-len(theta):], rotation=17)
# imkmm = ax6.imshow(kmm_dict[self.showing][0])
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# divider = make_axes_locatable(ax6)
# caxkmm = divider.append_axes("right", "5%", pad="1%")
# cbarkmm = pylab.colorbar(imkmm, cax=caxkmm)
#
# imkmmdl = ax7.imshow(kmm_dict[self.showing][1])
# divider = make_axes_locatable(ax7)
# caxkmmdl = divider.append_axes("right", "5%", pad="1%")
# cbarkmmdl = pylab.colorbar(imkmmdl, cax=caxkmmdl)
# input_dimleg = ax1.legend(Xlatentplts, [r"$input_dim_{}$".format(i + 1) for i in range(self.input_dim)],
# loc=3, ncol=self.input_dim, bbox_to_anchor=(0, 1.15, 1, 1.15),
# borderaxespad=0, mode="expand")
ax2.legend(Xlatentplts, [r"$input_dim_{}$".format(i + 1) for i in range(self.input_dim)],
loc=3, ncol=self.input_dim, bbox_to_anchor=(0, 1.1, 1, 1.1),
borderaxespad=0, mode="expand")
ax3.legend(Xlatentplts, [r"$input_dim_{}$".format(i + 1) for i in range(self.input_dim)],
loc=3, ncol=self.input_dim, bbox_to_anchor=(0, 1.1, 1, 1.1),
borderaxespad=0, mode="expand")
ax4.legend(Xlatentplts, [r"$input_dim_{}$".format(i + 1) for i in range(self.input_dim)],
loc=3, ncol=self.input_dim, bbox_to_anchor=(0, 1.1, 1, 1.1),
borderaxespad=0, mode="expand")
ax5.legend(Xlatentplts, [r"$input_dim_{}$".format(i + 1) for i in range(self.input_dim)],
loc=3, ncol=self.input_dim, bbox_to_anchor=(0, 1.1, 1, 1.1),
borderaxespad=0, mode="expand")
Lleg = ax1.legend()
Lleg.draggable()
# ax1.add_artist(input_dimleg)
indicatorKL, = ax1.plot(kllls[self.showing, 0], kllls[self.showing, 2], 'o', c=KL.get_color())
indicatorLL, = ax1.plot(kllls[self.showing, 0], kllls[self.showing, 1] - kllls[self.showing, 2], 'o', c=LL.get_color())
indicatorL, = ax1.plot(kllls[self.showing, 0], kllls[self.showing, 1], 'o', c=L.get_color())
# for err in self._savederrors:
# if err < kllls.shape[0]:
# ax1.scatter(kllls[err, 0], kllls[err, 2], s=50, marker=(5, 2), c=KL.get_color())
# ax1.scatter(kllls[err, 0], kllls[err, 1] - kllls[err, 2], s=50, marker=(5, 2), c=LL.get_color())
# ax1.scatter(kllls[err, 0], kllls[err, 1], s=50, marker=(5, 2), c=L.get_color())
# try:
# for f in figs:
# f.canvas.draw()
# f.tight_layout(box=(0, .15, 1, .9))
# # pylab.draw()
# # pylab.tight_layout(box=(0, .1, 1, .9))
# except:
# pass
# parameter changes
# ax2 = pylab.subplot2grid((4, 1), (1, 0), 3, 1, projection='3d')
button_options = [0, 0] # [0]: clicked -- [1]: dragged
def update_plots(event):
if button_options[0] and not button_options[1]:
# event.button, event.x, event.y, event.xdata, event.ydata)
tmp = np.abs(iters - event.xdata)
closest_hit = iters[tmp == tmp.min()][0]
if closest_hit != self.showing:
self.showing = closest_hit
# print closest_hit, iters, event.xdata
indicatorLL.set_data(self.showing, kllls[self.showing, 1] - kllls[self.showing, 2])
indicatorKL.set_data(self.showing, kllls[self.showing, 2])
indicatorL.set_data(self.showing, kllls[self.showing, 1])
X, S, Z, theta = self._debug_filter_params(param_dict[self.showing])
Xg, Sg, Zg, thetag = self._debug_filter_params(gradient_dict[self.showing])
# Xg, Sg, Zg, thetag = -Xg, -Sg, -Zg, -thetag
for i, Xlatent in enumerate(Xlatentplts):
Xlatent.set_ydata(X[:, i])
Xlatentgrads.set_offsets(np.array([xlatent.ravel(), X.ravel()]).T)
Xlatentgrads.set_UVC(Ulatent, Xg)
for i, Slatent in enumerate(Slatentplts):
Slatent.set_ydata(S[:, i])
Slatentgrads.set_offsets(np.array([xlatent.ravel(), S.ravel()]).T)
Slatentgrads.set_UVC(Ulatent, Sg)
for i, Zlatent in enumerate(Zplts):
Zlatent.set_ydata(Z[:, i])
Zgrads.set_offsets(np.array([xZ.ravel(), Z.ravel()]).T)
Zgrads.set_UVC(UZ, Zg)
for p, t in zip(thetaplts, theta):
p.set_height(t)
thetagrads.set_offsets(np.array([xtheta.ravel(), theta.ravel()]).T)
thetagrads.set_UVC(Utheta, thetag)
# imkmm.set_data(kmm_dict[self.showing][0])
# imkmm.autoscale()
# cbarkmm.update_normal(imkmm)
#
# imkmmdl.set_data(kmm_dict[self.showing][1])
# imkmmdl.autoscale()
# cbarkmmdl.update_normal(imkmmdl)
ax2.relim()
# ax3.relim()
ax4.relim()
ax5.relim()
ax2.autoscale()
# ax3.autoscale()
ax4.autoscale()
ax5.autoscale()
[fig.canvas.draw() for fig in figs]
button_options[0] = 0
button_options[1] = 0
def onclick(event):
if event.inaxes is ax1 and event.button == 1:
button_options[0] = 1
def motion(event):
if button_options[0]:
button_options[1] = 1
cidr = figs[0].canvas.mpl_connect('button_release_event', update_plots)
cidp = figs[0].canvas.mpl_connect('button_press_event', onclick)
cidd = figs[0].canvas.mpl_connect('motion_notify_event', motion)
return ax1, ax2, ax3, ax4, ax5 # , ax6, ax7
def latent_cost_and_grad(mu_S, kern, Z, dL_dpsi0, dL_dpsi1, dL_dpsi2):
"""
objective function for fitting the latent variables for test points

View file

@ -61,12 +61,14 @@ class MRD(Model):
assert not ('kernel' in kw), "pass kernels through `kernels` argument"
self.input_dim = input_dim
self.num_inducing = num_inducing
self._debug = _debug
self.num_inducing = num_inducing
self._init = True
X = self._init_X(initx, likelihood_or_Y_list)
Z = self._init_Z(initz, X)
self.num_inducing = Z.shape[0] # ensure M==N if M>N
self.bgplvms = [BayesianGPLVM(l, input_dim=input_dim, kernel=k, X=X, Z=Z, num_inducing=self.num_inducing, **kw) for l, k in zip(likelihood_or_Y_list, kernels)]
del self._init
@ -75,12 +77,35 @@ class MRD(Model):
self.nparams = nparams.cumsum()
self.num_data = self.gref.num_data
self.NQ = self.num_data * self.input_dim
self.MQ = self.num_inducing * self.input_dim
Model.__init__(self)
self._set_params(self._get_params())
def __getstate__(self):
return [self.names,
self.bgplvms,
self.gref,
self.nparams,
self.input_dim,
self.num_inducing,
self.num_data,
self.NQ,
self.MQ]
def __setstate__(self, state):
self.MQ = state.pop()
self.NQ = state.pop()
self.num_data = state.pop()
self.num_inducing = state.pop()
self.input_dim = state.pop()
self.nparams = state.pop()
self.gref = state.pop()
self.bgplvms = state.pop()
self.names = state.pop()
@property
def X(self):
return self.gref.X
@ -257,7 +282,7 @@ class MRD(Model):
def _handle_plotting(self, fignum, axes, plotf):
if axes is None:
fig = pylab.figure(num=fignum, figsize=(4 * len(self.bgplvms), 3))
fig = pylab.figure(num=fignum)
for i, g in enumerate(self.bgplvms):
if axes is None:
ax = fig.add_subplot(1, len(self.bgplvms), i + 1)
@ -285,11 +310,11 @@ class MRD(Model):
return fig
def plot_scales(self, fignum=None, ax=None, *args, **kwargs):
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(ax=ax, *args, **kwargs))
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(ax=ax, title=r'$Y_{}$'.format(i), *args, **kwargs))
return fig
def plot_latent(self, fignum=None, ax=None, *args, **kwargs):
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs))
fig = self.gref.plot_X_1d(*args, **kwargs) # self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs))
return fig
def _debug_plot(self):