From 68f0af4deb139041fb68782df26e192b8248be38 Mon Sep 17 00:00:00 2001 From: James Hensman Date: Wed, 26 Feb 2014 17:13:47 +0000 Subject: [PATCH] fixes in the plotting and in the dot graphing --- GPy/core/parameterization/param.py | 6 +++++- GPy/core/parameterization/parameterized.py | 6 +++++- GPy/plotting/matplot_dep/base_plots.py | 20 ++++++++++---------- GPy/plotting/matplot_dep/models_plots.py | 2 +- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py index b48bec0d..3b08eebd 100644 --- a/GPy/core/parameterization/param.py +++ b/GPy/core/parameterization/param.py @@ -65,7 +65,11 @@ class Param(Constrainable, ObservableArray, Gradcheckable): node = pydot.Node(id(self), shape='record', label=self.name) G.add_node(node) for o in self._observer_callables_.keys(): - print o, self.hirarchy_name() + label = o.name if hasattr(o, 'name') else str(o) + observed_node = pydot.Node(id(o), label=label) + G.add_node(observed_node) + edge = pydot.Edge(str(id(self)), str(id(o)), color='darkorange2', arrowhead='vee') + G.add_edge(edge) return node diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index 83a62161..fb606278 100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -77,7 +77,11 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable): G.add_edge(pydot.Edge(node, child_node)) for o in self._observer_callables_.keys(): - print id(o), self.hirarchy_name() + label = o.name if hasattr(o, 'name') else str(o) + observed_node = pydot.Node(id(o), label=label) + G.add_node(observed_node) + edge = pydot.Edge(str(id(self)), str(id(o)), color='darkorange2', arrowhead='vee') + G.add_edge(edge) if iamroot: return G diff --git a/GPy/plotting/matplot_dep/base_plots.py b/GPy/plotting/matplot_dep/base_plots.py index a9d25223..e86ef6ca 100644 --- a/GPy/plotting/matplot_dep/base_plots.py +++ b/GPy/plotting/matplot_dep/base_plots.py @@ -18,17 +18,17 @@ def meanplot(x, mu, color=Tango.colorsHex['darkBlue'], ax=None, fignum=None, lin _, axes = ax_default(fignum, ax) #here's the mean return axes.plot(x,mu,color=color,linewidth=linewidth,**kw) - + def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs): - _, axes = ax_default(ax, fignum) - + _, axes = ax_default(fignum, ax) + mu = mu.flatten() x = x.flatten() lower = lower.flatten() upper = upper.flatten() plots = [] - + #here's the mean plots.append(meanplot(x, mu, edgecol, axes)) @@ -39,15 +39,15 @@ def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.co plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)) #this is the edge: - plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,axes=axes)) - plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,axes=axes)) - + plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,ax=axes)) + plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,ax=axes)) + axes.set_xlabel(xlabel) axes.set_ylabel(ylabel) - + return plots - - + + def removeRightTicks(ax=None): ax = ax or pb.gca() for i, line in enumerate(ax.get_yticklines()): diff --git a/GPy/plotting/matplot_dep/models_plots.py b/GPy/plotting/matplot_dep/models_plots.py index 5123f514..d72d2a3e 100644 --- a/GPy/plotting/matplot_dep/models_plots.py +++ b/GPy/plotting/matplot_dep/models_plots.py @@ -86,7 +86,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all', m, v, lower, upper = model.predict(Xgrid) Y = Y for d in which_data_ycols: - gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], axes=ax, edgecol=linecol, fillcol=fillcol) + gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol) ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], 'kx', mew=1.5) #optionally plot some samples