From 264f0d21b61ec0ec964fa4df9e33171af40dcfac Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Tue, 23 Apr 2013 16:34:31 +0100 Subject: [PATCH] kern stash conflict --- GPy/examples/dimensionality_reduction.py | 84 +++++++---- GPy/inference/natural_gradient_scg.py | 146 +++++++++++++++++++ GPy/models/Bayesian_GPLVM.py | 178 ++++++++++++++++++++++- GPy/models/mrd.py | 23 --- 4 files changed, 370 insertions(+), 61 deletions(-) create mode 100644 GPy/inference/natural_gradient_scg.py diff --git a/GPy/examples/dimensionality_reduction.py b/GPy/examples/dimensionality_reduction.py index 8c8e23fe..e5f50237 100644 --- a/GPy/examples/dimensionality_reduction.py +++ b/GPy/examples/dimensionality_reduction.py @@ -112,14 +112,14 @@ def _simulate_sincos(D1, D2, D3, N, M, Q, plot_sim=False): s3 = s3(x) sS = sS(x) - s1 -= s1.mean() - s2 -= s2.mean() - s3 -= s3.mean() - sS -= sS.mean() - s1 /= .5 * (np.abs(s1).max() - np.abs(s1).min()) - s2 /= .5 * (np.abs(s2).max() - np.abs(s2).min()) - s3 /= .5 * (np.abs(s3).max() - np.abs(s3).min()) - sS /= .5 * (np.abs(sS).max() - np.abs(sS).min()) +# s1 -= s1.mean() +# s2 -= s2.mean() +# s3 -= s3.mean() +# sS -= sS.mean() +# s1 /= .5 * (np.abs(s1).max() - np.abs(s1).min()) +# s2 /= .5 * (np.abs(s2).max() - np.abs(s2).min()) +# s3 /= .5 * (np.abs(s3).max() - np.abs(s3).min()) +# sS /= .5 * (np.abs(sS).max() - np.abs(sS).min()) S1 = np.hstack([s1, sS]) S2 = np.hstack([s2, sS]) @@ -129,9 +129,9 @@ def _simulate_sincos(D1, D2, D3, N, M, Q, plot_sim=False): Y2 = S2.dot(np.random.randn(S2.shape[1], D2)) Y3 = S3.dot(np.random.randn(S3.shape[1], D3)) - Y1 += .5 * np.random.randn(*Y1.shape) - Y2 += .5 * np.random.randn(*Y2.shape) - Y3 += .5 * np.random.randn(*Y3.shape) + Y1 += .3 * np.random.randn(*Y1.shape) + Y2 += .3 * np.random.randn(*Y2.shape) + Y3 += .3 * np.random.randn(*Y3.shape) Y1 -= Y1.mean(0) Y2 -= Y2.mean(0) @@ -162,8 +162,11 @@ def _simulate_sincos(D1, D2, D3, N, M, Q, plot_sim=False): return slist, [S1, S2, S3], Ylist -def bgplvm_simulation(burnin='scg', plot_sim=False, max_f_eval=12): - D1, D2, D3, N, M, Q = 2000, 8, 8, 500, 2, 6 +def bgplvm_simulation(burnin='scg', plot_sim=False, + max_burnin=100, true_X=False, + do_opt=True, + max_f_eval=1000): + D1, D2, D3, N, M, Q = 10, 8, 8, 50, 30, 5 slist, Slist, Ylist = _simulate_sincos(D1, D2, D3, N, M, Q, plot_sim) from GPy.models import mrd @@ -171,53 +174,73 @@ def bgplvm_simulation(burnin='scg', plot_sim=False, max_f_eval=12): reload(mrd); reload(kern) - Y = Ylist[1] + Y = Ylist[0] k = kern.linear(Q, ARD=True) + kern.white(Q, .00001) # + kern.bias(Q) - m = Bayesian_GPLVM(Y, Q, init="PCA", M=M, kernel=k) +# k = kern.white(Q, .00001) + kern.bias(Q) + m = Bayesian_GPLVM(Y, Q, init="PCA", M=M, kernel=k, _debug=True) # m.set('noise',) + m.ensure_default_constraints() # m.auto_scale_factor = True # m.scale_factor = 1. - m.ensure_default_constraints() if burnin: print "initializing beta" cstr = "noise" - m.unconstrain(cstr); m.constrain_fixed(cstr, Y.var() / 100.) - m.optimize(burnin, messages=1, max_f_eval=max_f_eval) + m.unconstrain(cstr); m.constrain_fixed(cstr, Y.var() / 70.) + m.optimize(burnin, messages=1, max_f_eval=max_burnin) print "releasing beta" cstr = "noise" m.unconstrain(cstr); m.constrain_positive(cstr) - true_X = np.hstack((slist[1], slist[3], 0. * np.ones((N, Q - 2)))) - m.set('X_\d', true_X) - m.constrain_fixed("X_\d") + if true_X: + true_X = np.hstack((slist[0], slist[3], 0. 
* np.ones((N, Q - 2))))
+        m.set('X_\d', true_X)
+        m.constrain_fixed("X_\d")
-# #     cstr = 'variance'
-# #     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-10, 1.)
+    cstr = 'X_variance'
+#     m.unconstrain(cstr), m.constrain_fixed(cstr, .0001)
+    m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-7, .1)
+
+#     cstr = 'X_variance'
+#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-3, 1.)
+
+    m.set('X_var', np.ones(N * Q) * .5 + np.random.randn(N * Q) * .01)
+
+#     cstr = "iip"
+#     m.unconstrain(cstr); m.constrain_fixed(cstr)
+
+#     cstr = 'variance'
+#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-10, 1.)
 
 #     cstr = 'X_\d'
-#     m.unconstrain(cstr), m.constrain_bounded(cstr, -100., 100.)
+#     m.unconstrain(cstr), m.constrain_bounded(cstr, -10., 10.)
 #
 #     cstr = 'noise'
-#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-3, 1.)
+#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-5, 1.)
 #
 #     cstr = 'white'
 #     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-6, 1.)
 #
 #     cstr = 'linear_variance'
-#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-10, 10.) # m.constrain_positive(cstr)
-#
-#     cstr = 'X_variance'
-#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-10, 1.) # m.constrain_positive(cstr)
+#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-10, 10.)
+
+#     cstr = 'variance'
+#     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-10, 10.)
 
 #     np.seterr(all='call')
 #     def ipdbonerr(errtype, flags):
 #         import ipdb; ipdb.set_trace()
 #     np.seterrcall(ipdbonerr)
-
+    if do_opt and burnin:
+        try:
+            m.optimize(burnin, messages=1, max_f_eval=max_f_eval)
+        except:
+            pass
+        finally:
+            return m
     return m
 
 def mrd_simulation(plot_sim=False):
@@ -261,6 +284,7 @@ def mrd_simulation(plot_sim=False):
         m.set('{}_noise'.format(i + 1), Y.var() / 100.)
 
     m.ensure_default_constraints()
+    m.auto_scale_factor = True
 
 #     cstr = 'variance'
 #     m.unconstrain(cstr), m.constrain_bounded(cstr, 1e-12, 1.)
diff --git a/GPy/inference/natural_gradient_scg.py b/GPy/inference/natural_gradient_scg.py
new file mode 100644
index 00000000..ca42acfe
--- /dev/null
+++ b/GPy/inference/natural_gradient_scg.py
@@ -0,0 +1,146 @@
+# Copyright I. Nabney, N. Lawrence and James Hensman (1996 - 2012)
+
+# Scaled Conjugate Gradients, originally in Matlab as part of the Netlab toolbox by I. Nabney, converted to Python by N. Lawrence and given a pythonic interface by James Hensman
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT
+# HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# REGENTS OR CONTRIBUTORS BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
+# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+# HOWEVER CAUSED AND ON ANY THEORY OF
+# LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE. 
+
+
+import numpy as np
+import sys
+
+def SCG(f, gradf, x, optargs=(), maxiters=500, max_f_eval=500, display=True, xtol=1e-6, ftol=1e-6):
+    """
+    Optimisation through Scaled Conjugate Gradients (SCG)
+
+    f : the objective function
+    gradf : the gradient function (should return a 1D np.ndarray)
+    x : the initial condition
+    optargs : extra arguments passed through to f and gradf
+
+    Returns
+    x : the optimal value for x
+    flog : a list of all the objective values
+    function_eval : the number of objective evaluations used
+    status : a string describing how the optimisation terminated
+
+    """
+
+    sigma0 = 1.0e-4
+    fold = f(x, *optargs) # Initial function value.
+    function_eval = 1
+    fnow = fold
+    gradnew = gradf(x, *optargs) # Initial gradient.
+    gradold = gradnew.copy()
+    d = -gradnew # Initial search direction.
+    success = True # Force calculation of directional derivs.
+    nsuccess = 0 # nsuccess counts number of successes.
+    beta = 1.0 # Initial scale parameter.
+    betamin = 1.0e-15 # Lower bound on scale.
+    betamax = 1.0e100 # Upper bound on scale.
+    status = "Not converged"
+
+    flog = [fold]
+
+    iteration = 0
+
+    # Main optimization loop.
+    while iteration < maxiters:
+
+        # Calculate first and second directional derivatives.
+        if success:
+            mu = np.dot(d, gradnew)
+            if mu >= 0:
+                d = -gradnew
+                mu = np.dot(d, gradnew)
+            kappa = np.dot(d, d)
+            sigma = sigma0/np.sqrt(kappa)
+            xplus = x + sigma*d
+            gplus = gradf(xplus, *optargs)
+            theta = np.dot(d, (gplus - gradnew))/sigma
+
+        # Increase effective curvature and evaluate step size alpha.
+        delta = theta + beta*kappa
+        if delta <= 0:
+            delta = beta*kappa
+            beta = beta - theta/kappa
+
+        alpha = -mu/delta
+
+        # Calculate the comparison ratio.
+        xnew = x + alpha*d
+        fnew = f(xnew, *optargs)
+        function_eval += 1
+
+        if function_eval >= max_f_eval:
+            status = "Maximum number of function evaluations exceeded"
+            return x, flog, function_eval, status
+
+        Delta = 2.*(fnew - fold)/(alpha*mu)
+        if Delta >= 0.:
+            success = True
+            nsuccess += 1
+            x = xnew
+            fnow = fnew
+        else:
+            success = False
+            fnow = fold
+
+        # Store relevant variables
+        flog.append(fnow) # Current function value
+
+        iteration += 1
+        if display:
+            print '\r',
+            print 'Iteration: {0:>5g} Objective:{1:> 12e} Scale:{2:> 12e}'.format(iteration, fnow, beta),
+            # print 'Iteration:', iteration, ' Objective:', fnow, ' Scale:', beta, '\r',
+            sys.stdout.flush()
+
+        if success:
+            # Test for termination
+            if (np.max(np.abs(alpha*d)) < xtol) or (np.abs(fnew-fold) < ftol):
+                status = 'converged'
+                return x, flog, function_eval, status
+
+            else:
+                # Update variables for new position
+                fold = fnew
+                gradold = gradnew
+                gradnew = gradf(x, *optargs)
+                # If the gradient is zero then we are done.
+                if np.dot(gradnew, gradnew) == 0:
+                    status = 'converged'
+                    return x, flog, function_eval, status
+
+        # Adjust beta according to comparison ratio.
+        if Delta < 0.25:
+            beta = min(4.0*beta, betamax)
+        if Delta > 0.75:
+            beta = max(0.5*beta, betamin)
+
+        # Update search direction using Polak-Ribiere formula, or re-start
+        # in direction of negative gradient after nparams steps.
+        if nsuccess == x.size:
+            d = -gradnew
+            nsuccess = 0
+        elif success:
+            gamma = np.dot(gradold - gradnew, gradnew)/(mu)
+            d = gamma*d - gradnew
+
+    # If we get here, then we haven't terminated in the given number of
+    # iterations.
+    status = "maxiter exceeded"
+
+    return x, flog, function_eval, status
diff --git a/GPy/models/Bayesian_GPLVM.py b/GPy/models/Bayesian_GPLVM.py
index a23368de..0646b25f 100644
--- a/GPy/models/Bayesian_GPLVM.py
+++ b/GPy/models/Bayesian_GPLVM.py
@@ -10,6 +10,7 @@
 from GPy.util.linalg import pdinv
 from ..likelihoods import Gaussian
 from .. 
import kern
 from numpy.linalg.linalg import LinAlgError
+import itertools
 
 class Bayesian_GPLVM(sparse_GP, GPLVM):
     """
@@ -23,7 +24,9 @@ class Bayesian_GPLVM(sparse_GP, GPLVM):
     :type init: 'PCA'|'random'
 
     """
-    def __init__(self, Y, Q, X=None, X_variance=None, init='PCA', M=10, Z=None, kernel=None, oldpsave=5, **kwargs):
+    def __init__(self, Y, Q, X=None, X_variance=None, init='PCA', M=10,
+                 Z=None, kernel=None, oldpsave=5, _debug=False,
+                 **kwargs):
         if X == None:
             X = self.initialise_latent(init, Q, Y)
@@ -39,6 +42,12 @@ class Bayesian_GPLVM(sparse_GP, GPLVM):
         self.oldpsave = oldpsave
         self._oldps = []
+        self._debug = _debug
+
+        if self._debug:
+            self._count = itertools.count()
+            self._savedklll = []
+            self._savedparams = []
 
         sparse_GP.__init__(self, X, Gaussian(Y), kernel, Z=Z, X_variance=X_variance, **kwargs)
@@ -70,16 +79,18 @@ class Bayesian_GPLVM(sparse_GP, GPLVM):
         x = np.hstack((self.X.flatten(), self.X_variance.flatten(), sparse_GP._get_params(self)))
         return x
 
-    def _set_params(self, x, save_old=True):
+    def _set_params(self, x, save_old=True, save_count=0):
         try:
             N, Q = self.N, self.Q
             self.X = x[:self.X.size].reshape(N, Q).copy()
             self.X_variance = x[(N * Q):(2 * N * Q)].reshape(N, Q).copy()
             sparse_GP._set_params(self, x[(2 * N * Q):])
             self.oldps = x
-        except (LinAlgError, FloatingPointError):
-            print "\rWARNING: Caught LinAlgError, reconstructing old state "
-            self._set_params(self.oldps[-1], save_old=False)
+        except (LinAlgError, FloatingPointError, ZeroDivisionError):
+            print "\rWARNING: Caught numerical error, continuing without setting "
+#             if save_count > 10:
+#                 raise
+#             self._set_params(self.oldps[-1], save_old=False, save_count=save_count + 1)
 
     def dKL_dmuS(self):
         dKL_dS = (1. - (1. / (self.X_variance))) * 0.5
@@ -103,15 +114,29 @@ class Bayesian_GPLVM(sparse_GP, GPLVM):
     def log_likelihood(self):
         ll = sparse_GP.log_likelihood(self)
         kl = self.KL_divergence()
-        return ll + kl
+
+#         if ll < -2E4:
+#             ll = -2E4 + np.random.randn()
+#         if kl > 5E4:
+#             kl = 5E4 + np.random.randn()
+
+        if self._debug:
+            f_call = self._count.next()
+            self._savedklll.append([f_call, ll, kl])
+            if f_call % 1 == 0: # every call; raise the modulus to thin the saved trace
+                self._savedparams.append([f_call, self._get_params()])
+
+        # print "\nkl:", kl, "ll:", ll
+        return ll - kl
 
     def _log_likelihood_gradients(self):
         dKL_dmu, dKL_dS = self.dKL_dmuS()
         dL_dmu, dL_dS = self.dL_dmuS()
         # TODO: find way to make faster
-        d_dmu = (dL_dmu + dKL_dmu).flatten()
-        d_dS = (dL_dS + dKL_dS).flatten()
+        d_dmu = (dL_dmu - dKL_dmu).flatten()
+        d_dS = (dL_dS - dKL_dS).flatten()
         # TEST KL: ====================
         # d_dmu = (dKL_dmu).flatten()
         # d_dS = (dKL_dS).flatten()
@@ -135,3 +160,140 @@ class Bayesian_GPLVM(sparse_GP, GPLVM):
         ax = GPLVM.plot_latent(self, which_indices=[input_1, input_2], *args, **kwargs)
         ax.plot(self.Z[:, input_1], self.Z[:, input_2], '^w')
         return ax
+
+    def plot_X_1d(self, fig_num="MRD X 1d", axes=None, colors=None):
+        import pylab
+
+        fig = pylab.figure(num=fig_num, figsize=(8, min(12, (2 * self.X.shape[1]))))
+        if colors is None:
+            colors = pylab.gca()._get_lines.color_cycle
+        pylab.clf()
+        plots = []
+        for i in range(self.X.shape[1]):
+            if axes is None:
+                ax = fig.add_subplot(self.X.shape[1], 1, i + 1)
+            else:
+                ax = axes[i]
+            ax.plot(self.X, c='k', alpha=.3)
+            plots.extend(ax.plot(self.X.T[i], c=colors.next(), label=r"$\mathbf{{X_{}}}$".format(i)))
+            ax.fill_between(np.arange(self.X.shape[0]),
+                            self.X.T[i] - 2 * np.sqrt(self.X_variance.T[i]),
+                            self.X.T[i] + 2 * np.sqrt(self.X_variance.T[i]),
+                            facecolor=plots[-1].get_color(),
+                            
alpha=.3)
+            ax.legend(borderaxespad=0.)
+            if i < self.X.shape[1] - 1:
+                ax.set_xticklabels('')
+        pylab.draw()
+        fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
+        return fig
+
+    def _debug_filter_params(self, x):
+        start, end = 0, self.X.size
+        X = x[start:end].reshape(self.N, self.Q)
+        start, end = end, end + self.X_variance.size
+        X_v = x[start:end].reshape(self.N, self.Q)
+        start, end = end, end + (self.M * self.Q)
+        Z = x[start:end].reshape(self.M, self.Q)
+        start, end = end, end + self.Q
+        theta = x[start:]
+        return X, X_v, Z, theta
+
+    def _debug_plot(self):
+        assert self._debug, "must enable _debug to debug-plot"
+        import pylab
+        from mpl_toolkits.mplot3d import Axes3D
+        fig = pylab.figure('BGPLVM DEBUG', figsize=(12, 10))
+        fig.clf()
+
+        # log like
+        splotshape = (6, 4)
+        ax1 = pylab.subplot2grid(splotshape, (0, 0), 1, 4)
+        ax1.text(.5, .5, "Optimization", alpha=.3, transform=ax1.transAxes,
+                 ha='center', va='center')
+        kllls = np.array(self._savedklll)
+        LL, = ax1.plot(kllls[:, 0], kllls[:, 1] - kllls[:, 2], label=r'$\log p(\mathbf{Y})$', mew=1.5)
+        KL, = ax1.plot(kllls[:, 0], kllls[:, 2], label=r'$\mathcal{KL}(q||p)$', mew=1.5)
+        L, = ax1.plot(kllls[:, 0], kllls[:, 1], label=r'$L$', mew=1.5) # \mathds{E}_{q(\mathbf{X})}[p(\mathbf{Y|X})\frac{p(\mathbf{X})}{q(\mathbf{X})}]
+
+        drawn = dict(self._savedparams)
+        iters = np.array(drawn.keys())
+        self.showing = 0
+
+        ax2 = pylab.subplot2grid(splotshape, (1, 0), 2, 4)
+        ax2.text(.5, .5, r"$\mathbf{X}$", alpha=.5, transform=ax2.transAxes,
+                 ha='center', va='center')
+        ax3 = pylab.subplot2grid(splotshape, (3, 0), 2, 4, sharex=ax2)
+        ax3.text(.5, .5, r"$\mathbf{S}$", alpha=.5, transform=ax3.transAxes,
+                 ha='center', va='center')
+        ax4 = pylab.subplot2grid(splotshape, (5, 0), 2, 2)
+        ax4.text(.5, .5, r"$\mathbf{Z}$", alpha=.5, transform=ax4.transAxes,
+                 ha='center', va='center')
+        ax5 = pylab.subplot2grid(splotshape, (5, 2), 2, 2)
+        ax5.text(.5, .5, r"${\theta}$", alpha=.5, transform=ax5.transAxes,
+                 ha='center', va='center')
+
+        X, S, Z, theta = self._debug_filter_params(drawn[self.showing])
+        Xlatentplts = ax2.plot(X, ls="-", marker="x")
+        Slatentplts = ax3.plot(S, ls="-", marker="x")
+        Zplts = ax4.plot(Z, ls="-", marker="x")
+        thetaplts = ax5.bar(np.arange(len(theta)) - .4, theta)
+        ax5.set_xticks(np.arange(len(theta)))
+        ax5.set_xticklabels(self._get_param_names()[-len(theta):], rotation=17)
+
+        Qleg = ax1.legend(Xlatentplts, [r"$Q_{}$".format(i + 1) for i in range(self.Q)],
+                          loc=3, ncol=self.Q, bbox_to_anchor=(0, 1.15, 1, 1.15),
+                          borderaxespad=0, mode="expand")
+        Lleg = ax1.legend()
+        Lleg.draggable()
+        ax1.add_artist(Qleg)
+
+        indicatorKL, = ax1.plot(kllls[self.showing, 0], kllls[self.showing, 2], 'o', c=KL.get_color())
+        indicatorLL, = ax1.plot(kllls[self.showing, 0], kllls[self.showing, 1] - kllls[self.showing, 2], 'o', c=LL.get_color())
+        indicatorL, = ax1.plot(kllls[self.showing, 0], kllls[self.showing, 1], 'o', c=L.get_color())
+
+        try:
+            pylab.draw()
+            pylab.tight_layout(rect=(0, .1, 1, .9))
+        except:
+            pass
+
+        # parameter changes
+        # ax2 = pylab.subplot2grid((4, 1), (1, 0), 3, 1, projection='3d')
+        def onclick(event):
+            if event.inaxes is ax1 and event.button == 1:
+#                 event.button, event.x, event.y, event.xdata, event.ydata)
+                tmp = np.abs(iters - event.xdata)
+                closest_hit = iters[tmp == tmp.min()][0]
+
+                if closest_hit != self.showing:
+                    self.showing = closest_hit
+                    # print closest_hit, iters, event.xdata
+
+                    indicatorLL.set_data(self.showing, kllls[self.showing, 1] - kllls[self.showing, 2])
+                    
indicatorKL.set_data(self.showing, kllls[self.showing, 2]) + indicatorL.set_data(self.showing, kllls[self.showing, 1]) + + X, S, Z, theta = self._debug_filter_params(drawn[self.showing]) + for i, Xlatent in enumerate(Xlatentplts): + Xlatent.set_ydata(X[:, i]) + for i, Slatent in enumerate(Slatentplts): + Slatent.set_ydata(S[:, i]) + for i, Zlatent in enumerate(Zplts): + Zlatent.set_ydata(Z[:, i]) + for p, t in zip(thetaplts, theta): + p.set_height(t) + + ax2.relim() + ax3.relim() + ax4.relim() + ax5.relim() + ax2.autoscale() + ax3.autoscale() + ax4.autoscale() + ax5.autoscale() + fig.canvas.draw() + + cid = fig.canvas.mpl_connect('button_press_event', onclick) + + return ax1, ax2, ax3, ax4, ax5 diff --git a/GPy/models/mrd.py b/GPy/models/mrd.py index 096c9cb9..4e0487b2 100644 --- a/GPy/models/mrd.py +++ b/GPy/models/mrd.py @@ -287,29 +287,6 @@ class MRD(model): else: return pylab.gcf() - def plot_X_1d(self, fig_num="MRD X 1d", axes=None, colors=None): - fig = pylab.figure(num=fig_num, figsize=(min(8, (3 * len(self.bgplvms))), min(12, (2 * self.X.shape[1])))) - if colors is None: - colors = pylab.gca()._get_lines.color_cycle - pylab.clf() - plots = [] - for i in range(self.X.shape[1]): - if axes is None: - ax = fig.add_subplot(self.X.shape[1], 1, i + 1) - ax.plot(self.X, c='k', alpha=.3) - plots.extend(ax.plot(self.X.T[i], c=colors.next(), label=r"$\mathbf{{X_{}}}$".format(i))) - ax.fill_between(numpy.arange(self.X.shape[0]), - self.X.T[i] - 2 * numpy.sqrt(self.gref.X_variance.T[i]), - self.X.T[i] + 2 * numpy.sqrt(self.gref.X_variance.T[i]), - facecolor=plots[-1].get_color(), - alpha=.3) - ax.legend(borderaxespad=0.) - if i < self.X.shape[1] - 1: - ax.set_xticklabels('') - pylab.draw() - fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95)) - return fig - def plot_X(self, fig_num="MRD Predictions", axes=None): fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: ax.imshow(g.X)) return fig
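--
Editor's notes (appended for readers of this patch; not part of the commit):

1. The SCG() routine added in GPy/inference/natural_gradient_scg.py is
   self-contained, so it can be exercised outside of GPy's model classes.
   Below is a minimal sketch on a toy quadratic; the objective f and
   gradient gradf are illustrative stand-ins, not part of the patch.

    import numpy as np
    from GPy.inference.natural_gradient_scg import SCG

    def f(x):
        # Toy objective 0.5 * ||x||^2, minimised at the origin.
        return 0.5 * np.dot(x, x)

    def gradf(x):
        # Gradient of the toy objective, as the 1D np.ndarray SCG expects.
        return x

    x0 = np.random.randn(5)
    x_opt, flog, n_evals, status = SCG(f, gradf, x0, display=False,
                                       maxiters=200, max_f_eval=200)
    print(status)  # 'converged' once the xtol/ftol test is met

   flog records the objective after every iteration, which makes it easy to
   inspect the burn-in behaviour that the example code in this patch relies on.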
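2. On the sign flip in Bayesian_GPLVM.log_likelihood (ll + kl becomes
   ll - kl) and the matching subtraction in _log_likelihood_gradients:
   assuming KL_divergence() returns the positive KL term (consistent with
   dKL_dmuS(), whose derivatives are those of KL(q||p) for q = N(mu, S)
   against a unit Gaussian prior), the quantity being maximised is the
   standard variational lower bound

    % variational lower bound on log p(Y)
    \mathcal{F}(q) = \mathbb{E}_{q(\mathbf{X})}\left[\log p(\mathbf{Y}\,|\,\mathbf{X})\right]
                   - \mathrm{KL}\left(q(\mathbf{X})\,\|\,p(\mathbf{X})\right)

   so the KL term and its gradients must enter with a minus sign, which is
   what this patch establishes.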
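3. The _debug machinery records [f_call, ll, kl] triples plus parameter
   snapshots on every objective evaluation, and _debug_plot() replays them
   in a click-driven figure. A sketch of the intended workflow, assuming
   the GPy tree at this commit and an interactive matplotlib backend:

    from GPy.examples.dimensionality_reduction import bgplvm_simulation

    # bgplvm_simulation() builds the model with _debug=True, burns in with
    # the noise variance pinned low, releases it, then optimises further.
    m = bgplvm_simulation(burnin='scg', max_burnin=100, max_f_eval=1000)

    # Click on the objective traces in the top panel to scrub through the
    # saved states of X, X_variance, Z and the kernel parameters.
    m._debug_plot()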