mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
Small convenience function for extracting fixed_inputs, fixed inputs can be set to their mean, median, or zero
This commit is contained in:
parent
9defc07672
commit
597eddcc81
2 changed files with 53 additions and 3 deletions
|
|
@ -1,5 +1,5 @@
|
|||
#===============================================================================
|
||||
# Copyright (c) 2015, Max Zwiessele
|
||||
# Copyright (c) 2016, Max Zwiessele, Alan saul
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
|
@ -131,6 +131,7 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_
|
|||
:param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix
|
||||
:type as_list: boolean
|
||||
"""
|
||||
from ...inference.latent_function_inference.posterior import VariationalPosterior
|
||||
f_inputs = []
|
||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
||||
X = model.X.mean.values.copy()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
#===============================================================================
|
||||
# Copyright (c) 2016, Max Zwiessele
|
||||
# Copyright (c) 2016, Max Zwiessele, Alan Saul
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
|
@ -47,3 +47,52 @@ class TestDebug(unittest.TestCase):
|
|||
|
||||
array = np.random.normal(0, 1, (25,25))
|
||||
self.assertTrue(checkFullRank(tdot(array)))
|
||||
|
||||
def test_fixed_inputs_median(self):
|
||||
""" test fixed_inputs convenience function """
|
||||
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||
import GPy
|
||||
X = np.random.randn(10, 3)
|
||||
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
|
||||
m = GPy.models.GPRegression(X, Y)
|
||||
fixed = fixed_inputs(m, [1], fix_routine='median', as_list=True, X_all=False)
|
||||
self.assertTrue((0, np.median(X[:,0])) in fixed)
|
||||
self.assertTrue((2, np.median(X[:,2])) in fixed)
|
||||
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||
|
||||
def test_fixed_inputs_mean(self):
|
||||
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||
import GPy
|
||||
X = np.random.randn(10, 3)
|
||||
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
|
||||
m = GPy.models.GPRegression(X, Y)
|
||||
fixed = fixed_inputs(m, [1], fix_routine='mean', as_list=True, X_all=False)
|
||||
self.assertTrue((0, np.mean(X[:,0])) in fixed)
|
||||
self.assertTrue((2, np.mean(X[:,2])) in fixed)
|
||||
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||
|
||||
def test_fixed_inputs_zero(self):
|
||||
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||
import GPy
|
||||
X = np.random.randn(10, 3)
|
||||
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
|
||||
m = GPy.models.GPRegression(X, Y)
|
||||
fixed = fixed_inputs(m, [1], fix_routine='zero', as_list=True, X_all=False)
|
||||
self.assertTrue((0, 0.0) in fixed)
|
||||
self.assertTrue((2, 0.0) in fixed)
|
||||
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||
|
||||
def test_fixed_inputs_uncertain(self):
|
||||
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||
import GPy
|
||||
from GPy.core.parameterization.variational import NormalPosterior
|
||||
X_mu = np.random.randn(10, 3)
|
||||
X_var = np.random.randn(10, 3)
|
||||
X = NormalPosterior(X_mu, X_var)
|
||||
Y = np.sin(X_mu) + np.random.randn(10, 3)*1e-3
|
||||
m = GPy.models.BayesianGPLVM(Y, X=X_mu, X_variance=X_var, input_dim=3)
|
||||
fixed = fixed_inputs(m, [1], fix_routine='median', as_list=True, X_all=False)
|
||||
self.assertTrue((0, np.median(X.mean.values[:,0])) in fixed)
|
||||
self.assertTrue((2, np.median(X.mean.values[:,2])) in fixed)
|
||||
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue