Small convenience function for extracting fixed_inputs, fixed inputs can be set to their mean, median, or zero

This commit is contained in:
Alan Saul 2016-03-24 15:24:53 +00:00
parent 9defc07672
commit 597eddcc81
2 changed files with 53 additions and 3 deletions

View file

@ -1,5 +1,5 @@
#=============================================================================== #===============================================================================
# Copyright (c) 2015, Max Zwiessele # Copyright (c) 2016, Max Zwiessele, Alan saul
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
@ -131,6 +131,7 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_
:param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix :param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix
:type as_list: boolean :type as_list: boolean
""" """
from ...inference.latent_function_inference.posterior import VariationalPosterior
f_inputs = [] f_inputs = []
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs(): if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean.values.copy() X = model.X.mean.values.copy()

View file

@ -1,5 +1,5 @@
#=============================================================================== #===============================================================================
# Copyright (c) 2016, Max Zwiessele # Copyright (c) 2016, Max Zwiessele, Alan Saul
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
@ -46,4 +46,53 @@ class TestDebug(unittest.TestCase):
self.assertFalse(checkFullRank(tdot(array), name='test')) self.assertFalse(checkFullRank(tdot(array), name='test'))
array = np.random.normal(0, 1, (25,25)) array = np.random.normal(0, 1, (25,25))
self.assertTrue(checkFullRank(tdot(array))) self.assertTrue(checkFullRank(tdot(array)))
def test_fixed_inputs_median(self):
""" test fixed_inputs convenience function """
from GPy.plotting.matplot_dep.util import fixed_inputs
import GPy
X = np.random.randn(10, 3)
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
m = GPy.models.GPRegression(X, Y)
fixed = fixed_inputs(m, [1], fix_routine='median', as_list=True, X_all=False)
self.assertTrue((0, np.median(X[:,0])) in fixed)
self.assertTrue((2, np.median(X[:,2])) in fixed)
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
def test_fixed_inputs_mean(self):
from GPy.plotting.matplot_dep.util import fixed_inputs
import GPy
X = np.random.randn(10, 3)
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
m = GPy.models.GPRegression(X, Y)
fixed = fixed_inputs(m, [1], fix_routine='mean', as_list=True, X_all=False)
self.assertTrue((0, np.mean(X[:,0])) in fixed)
self.assertTrue((2, np.mean(X[:,2])) in fixed)
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
def test_fixed_inputs_zero(self):
from GPy.plotting.matplot_dep.util import fixed_inputs
import GPy
X = np.random.randn(10, 3)
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
m = GPy.models.GPRegression(X, Y)
fixed = fixed_inputs(m, [1], fix_routine='zero', as_list=True, X_all=False)
self.assertTrue((0, 0.0) in fixed)
self.assertTrue((2, 0.0) in fixed)
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
def test_fixed_inputs_uncertain(self):
from GPy.plotting.matplot_dep.util import fixed_inputs
import GPy
from GPy.core.parameterization.variational import NormalPosterior
X_mu = np.random.randn(10, 3)
X_var = np.random.randn(10, 3)
X = NormalPosterior(X_mu, X_var)
Y = np.sin(X_mu) + np.random.randn(10, 3)*1e-3
m = GPy.models.BayesianGPLVM(Y, X=X_mu, X_variance=X_var, input_dim=3)
fixed = fixed_inputs(m, [1], fix_routine='median', as_list=True, X_all=False)
self.assertTrue((0, np.median(X.mean.values[:,0])) in fixed)
self.assertTrue((2, np.median(X.mean.values[:,2])) in fixed)
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed