fixes in the plotting and in the dot graphing

This commit is contained in:
James Hensman 2014-02-26 17:13:47 +00:00
parent 0e5a453358
commit 68f0af4deb
4 changed files with 21 additions and 13 deletions

View file

@ -65,7 +65,11 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
node = pydot.Node(id(self), shape='record', label=self.name) node = pydot.Node(id(self), shape='record', label=self.name)
G.add_node(node) G.add_node(node)
for o in self._observer_callables_.keys(): for o in self._observer_callables_.keys():
print o, self.hirarchy_name() label = o.name if hasattr(o, 'name') else str(o)
observed_node = pydot.Node(id(o), label=label)
G.add_node(observed_node)
edge = pydot.Edge(str(id(self)), str(id(o)), color='darkorange2', arrowhead='vee')
G.add_edge(edge)
return node return node

View file

@ -77,7 +77,11 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
G.add_edge(pydot.Edge(node, child_node)) G.add_edge(pydot.Edge(node, child_node))
for o in self._observer_callables_.keys(): for o in self._observer_callables_.keys():
print id(o), self.hirarchy_name() label = o.name if hasattr(o, 'name') else str(o)
observed_node = pydot.Node(id(o), label=label)
G.add_node(observed_node)
edge = pydot.Edge(str(id(self)), str(id(o)), color='darkorange2', arrowhead='vee')
G.add_edge(edge)
if iamroot: if iamroot:
return G return G

View file

@ -20,7 +20,7 @@ def meanplot(x, mu, color=Tango.colorsHex['darkBlue'], ax=None, fignum=None, lin
return axes.plot(x,mu,color=color,linewidth=linewidth,**kw) return axes.plot(x,mu,color=color,linewidth=linewidth,**kw)
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs): def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs):
_, axes = ax_default(ax, fignum) _, axes = ax_default(fignum, ax)
mu = mu.flatten() mu = mu.flatten()
x = x.flatten() x = x.flatten()
@ -39,8 +39,8 @@ def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.co
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)) plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
#this is the edge: #this is the edge:
plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,axes=axes)) plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,ax=axes))
plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,axes=axes)) plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,ax=axes))
axes.set_xlabel(xlabel) axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel) axes.set_ylabel(ylabel)

View file

@ -86,7 +86,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
m, v, lower, upper = model.predict(Xgrid) m, v, lower, upper = model.predict(Xgrid)
Y = Y Y = Y
for d in which_data_ycols: for d in which_data_ycols:
gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], axes=ax, edgecol=linecol, fillcol=fillcol) gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], 'kx', mew=1.5) ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], 'kx', mew=1.5)
#optionally plot some samples #optionally plot some samples