mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-03 00:32:39 +02:00
Added option for plotting with SVGP
This commit is contained in:
parent
afa0621488
commit
d3e79495e7
1 changed files with 9 additions and 3 deletions
|
|
@ -11,6 +11,7 @@ from base_plots import gpplot, x_frame1D, x_frame2D
|
|||
from ...models.gp_coregionalized_regression import GPCoregionalizedRegression
|
||||
from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression
|
||||
from scipy import sparse
|
||||
from ...core.parameterization.variational import VariationalPosterior
|
||||
|
||||
def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||
which_data_ycols='all', fixed_inputs=[],
|
||||
|
|
@ -78,7 +79,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
|||
|
||||
if predict_kw is None:
|
||||
predict_kw = {}
|
||||
|
||||
|
||||
#work out what the inputs are for plotting (1D or 2D)
|
||||
fixed_dims = np.array([i for i,v in fixed_inputs])
|
||||
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
|
||||
|
|
@ -219,7 +220,7 @@ def plot_fit_f(model, *args, **kwargs):
|
|||
kwargs['plot_raw'] = True
|
||||
plot_fit(model,*args, **kwargs)
|
||||
|
||||
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True):
|
||||
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
|
||||
"""
|
||||
Convenience function for returning back fixed_inputs where the other inputs
|
||||
are fixed using fix_routine
|
||||
|
|
@ -235,8 +236,13 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True):
|
|||
f_inputs = []
|
||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
||||
X = model.X.mean.values.copy()
|
||||
else:
|
||||
elif isinstance(model.X, VariationalPosterior):
|
||||
X = model.X.values.copy()
|
||||
else:
|
||||
if X_all:
|
||||
X = model.X_all.copy()
|
||||
else:
|
||||
X = model.X.copy()
|
||||
for i in range(X.shape[1]):
|
||||
if i not in non_fixed_inputs:
|
||||
if fix_routine == 'mean':
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue