mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 03:22:38 +02:00
fixes in the plotting and in the dot graphing
This commit is contained in:
parent
0e5a453358
commit
68f0af4deb
4 changed files with 21 additions and 13 deletions
|
|
@ -65,7 +65,11 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
node = pydot.Node(id(self), shape='record', label=self.name)
|
node = pydot.Node(id(self), shape='record', label=self.name)
|
||||||
G.add_node(node)
|
G.add_node(node)
|
||||||
for o in self._observer_callables_.keys():
|
for o in self._observer_callables_.keys():
|
||||||
print o, self.hirarchy_name()
|
label = o.name if hasattr(o, 'name') else str(o)
|
||||||
|
observed_node = pydot.Node(id(o), label=label)
|
||||||
|
G.add_node(observed_node)
|
||||||
|
edge = pydot.Edge(str(id(self)), str(id(o)), color='darkorange2', arrowhead='vee')
|
||||||
|
G.add_edge(edge)
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,11 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
G.add_edge(pydot.Edge(node, child_node))
|
G.add_edge(pydot.Edge(node, child_node))
|
||||||
|
|
||||||
for o in self._observer_callables_.keys():
|
for o in self._observer_callables_.keys():
|
||||||
print id(o), self.hirarchy_name()
|
label = o.name if hasattr(o, 'name') else str(o)
|
||||||
|
observed_node = pydot.Node(id(o), label=label)
|
||||||
|
G.add_node(observed_node)
|
||||||
|
edge = pydot.Edge(str(id(self)), str(id(o)), color='darkorange2', arrowhead='vee')
|
||||||
|
G.add_edge(edge)
|
||||||
|
|
||||||
if iamroot:
|
if iamroot:
|
||||||
return G
|
return G
|
||||||
|
|
|
||||||
|
|
@ -18,17 +18,17 @@ def meanplot(x, mu, color=Tango.colorsHex['darkBlue'], ax=None, fignum=None, lin
|
||||||
_, axes = ax_default(fignum, ax)
|
_, axes = ax_default(fignum, ax)
|
||||||
#here's the mean
|
#here's the mean
|
||||||
return axes.plot(x,mu,color=color,linewidth=linewidth,**kw)
|
return axes.plot(x,mu,color=color,linewidth=linewidth,**kw)
|
||||||
|
|
||||||
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs):
|
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs):
|
||||||
_, axes = ax_default(ax, fignum)
|
_, axes = ax_default(fignum, ax)
|
||||||
|
|
||||||
mu = mu.flatten()
|
mu = mu.flatten()
|
||||||
x = x.flatten()
|
x = x.flatten()
|
||||||
lower = lower.flatten()
|
lower = lower.flatten()
|
||||||
upper = upper.flatten()
|
upper = upper.flatten()
|
||||||
|
|
||||||
plots = []
|
plots = []
|
||||||
|
|
||||||
#here's the mean
|
#here's the mean
|
||||||
plots.append(meanplot(x, mu, edgecol, axes))
|
plots.append(meanplot(x, mu, edgecol, axes))
|
||||||
|
|
||||||
|
|
@ -39,15 +39,15 @@ def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.co
|
||||||
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
|
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
|
||||||
|
|
||||||
#this is the edge:
|
#this is the edge:
|
||||||
plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,axes=axes))
|
plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,ax=axes))
|
||||||
plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,axes=axes))
|
plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,ax=axes))
|
||||||
|
|
||||||
axes.set_xlabel(xlabel)
|
axes.set_xlabel(xlabel)
|
||||||
axes.set_ylabel(ylabel)
|
axes.set_ylabel(ylabel)
|
||||||
|
|
||||||
return plots
|
return plots
|
||||||
|
|
||||||
|
|
||||||
def removeRightTicks(ax=None):
|
def removeRightTicks(ax=None):
|
||||||
ax = ax or pb.gca()
|
ax = ax or pb.gca()
|
||||||
for i, line in enumerate(ax.get_yticklines()):
|
for i, line in enumerate(ax.get_yticklines()):
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
m, v, lower, upper = model.predict(Xgrid)
|
m, v, lower, upper = model.predict(Xgrid)
|
||||||
Y = Y
|
Y = Y
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], axes=ax, edgecol=linecol, fillcol=fillcol)
|
gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
|
||||||
ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], 'kx', mew=1.5)
|
ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], 'kx', mew=1.5)
|
||||||
|
|
||||||
#optionally plot some samples
|
#optionally plot some samples
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue