mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
fixed a plotting bug
This commit is contained in:
parent
90df33af50
commit
c8b54ff887
1 changed files with 5 additions and 8 deletions
|
|
@ -1,17 +1,14 @@
|
||||||
# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
try:
|
|
||||||
# import Tango
|
|
||||||
import pylab as pb
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from . import Tango
|
||||||
from base_plots import gpplot, x_frame1D, x_frame2D
|
from base_plots import gpplot, x_frame1D, x_frame2D
|
||||||
from ...models.gp_coregionalized_regression import GPCoregionalizedRegression
|
from ...models.gp_coregionalized_regression import GPCoregionalizedRegression
|
||||||
from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression
|
from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression
|
||||||
from scipy import sparse
|
from scipy import sparse
|
||||||
from ...core.parameterization.variational import VariationalPosterior
|
from ...core.parameterization.variational import VariationalPosterior
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
def plot_fit(model, plot_limits=None, which_data_rows='all',
|
def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
which_data_ycols='all', fixed_inputs=[],
|
which_data_ycols='all', fixed_inputs=[],
|
||||||
|
|
@ -64,7 +61,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
#if len(which_data_ycols)==0:
|
#if len(which_data_ycols)==0:
|
||||||
#raise ValueError('No data selected for plotting')
|
#raise ValueError('No data selected for plotting')
|
||||||
if ax is None:
|
if ax is None:
|
||||||
fig = pb.figure(num=fignum)
|
fig = plt.figure(num=fignum)
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
|
|
||||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
||||||
|
|
@ -197,8 +194,8 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
|
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
m_d = m[:,d].reshape(resolution, resolution).T
|
m_d = m[:,d].reshape(resolution, resolution).T
|
||||||
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)
|
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
|
||||||
if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
|
if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=plt.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
|
||||||
|
|
||||||
#set the limits of the plot to some sensible values
|
#set the limits of the plot to some sensible values
|
||||||
ax.set_xlim(xmin[0], xmax[0])
|
ax.set_xlim(xmin[0], xmax[0])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue