mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
Small convenience function for extracting fixed_inputs, fixed inputs can be set to their mean, median, or zero
This commit is contained in:
parent
9defc07672
commit
597eddcc81
2 changed files with 53 additions and 3 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
# Copyright (c) 2015, Max Zwiessele
|
# Copyright (c) 2016, Max Zwiessele, Alan saul
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Redistribution and use in source and binary forms, with or without
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
|
@ -131,6 +131,7 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_
|
||||||
:param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix
|
:param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix
|
||||||
:type as_list: boolean
|
:type as_list: boolean
|
||||||
"""
|
"""
|
||||||
|
from ...inference.latent_function_inference.posterior import VariationalPosterior
|
||||||
f_inputs = []
|
f_inputs = []
|
||||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
||||||
X = model.X.mean.values.copy()
|
X = model.X.mean.values.copy()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
# Copyright (c) 2016, Max Zwiessele
|
# Copyright (c) 2016, Max Zwiessele, Alan Saul
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Redistribution and use in source and binary forms, with or without
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
|
@ -46,4 +46,53 @@ class TestDebug(unittest.TestCase):
|
||||||
self.assertFalse(checkFullRank(tdot(array), name='test'))
|
self.assertFalse(checkFullRank(tdot(array), name='test'))
|
||||||
|
|
||||||
array = np.random.normal(0, 1, (25,25))
|
array = np.random.normal(0, 1, (25,25))
|
||||||
self.assertTrue(checkFullRank(tdot(array)))
|
self.assertTrue(checkFullRank(tdot(array)))
|
||||||
|
|
||||||
|
def test_fixed_inputs_median(self):
|
||||||
|
""" test fixed_inputs convenience function """
|
||||||
|
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||||
|
import GPy
|
||||||
|
X = np.random.randn(10, 3)
|
||||||
|
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
|
||||||
|
m = GPy.models.GPRegression(X, Y)
|
||||||
|
fixed = fixed_inputs(m, [1], fix_routine='median', as_list=True, X_all=False)
|
||||||
|
self.assertTrue((0, np.median(X[:,0])) in fixed)
|
||||||
|
self.assertTrue((2, np.median(X[:,2])) in fixed)
|
||||||
|
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||||
|
|
||||||
|
def test_fixed_inputs_mean(self):
|
||||||
|
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||||
|
import GPy
|
||||||
|
X = np.random.randn(10, 3)
|
||||||
|
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
|
||||||
|
m = GPy.models.GPRegression(X, Y)
|
||||||
|
fixed = fixed_inputs(m, [1], fix_routine='mean', as_list=True, X_all=False)
|
||||||
|
self.assertTrue((0, np.mean(X[:,0])) in fixed)
|
||||||
|
self.assertTrue((2, np.mean(X[:,2])) in fixed)
|
||||||
|
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||||
|
|
||||||
|
def test_fixed_inputs_zero(self):
|
||||||
|
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||||
|
import GPy
|
||||||
|
X = np.random.randn(10, 3)
|
||||||
|
Y = np.sin(X) + np.random.randn(10, 3)*1e-3
|
||||||
|
m = GPy.models.GPRegression(X, Y)
|
||||||
|
fixed = fixed_inputs(m, [1], fix_routine='zero', as_list=True, X_all=False)
|
||||||
|
self.assertTrue((0, 0.0) in fixed)
|
||||||
|
self.assertTrue((2, 0.0) in fixed)
|
||||||
|
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||||
|
|
||||||
|
def test_fixed_inputs_uncertain(self):
|
||||||
|
from GPy.plotting.matplot_dep.util import fixed_inputs
|
||||||
|
import GPy
|
||||||
|
from GPy.core.parameterization.variational import NormalPosterior
|
||||||
|
X_mu = np.random.randn(10, 3)
|
||||||
|
X_var = np.random.randn(10, 3)
|
||||||
|
X = NormalPosterior(X_mu, X_var)
|
||||||
|
Y = np.sin(X_mu) + np.random.randn(10, 3)*1e-3
|
||||||
|
m = GPy.models.BayesianGPLVM(Y, X=X_mu, X_variance=X_var, input_dim=3)
|
||||||
|
fixed = fixed_inputs(m, [1], fix_routine='median', as_list=True, X_all=False)
|
||||||
|
self.assertTrue((0, np.median(X.mean.values[:,0])) in fixed)
|
||||||
|
self.assertTrue((2, np.median(X.mean.values[:,2])) in fixed)
|
||||||
|
self.assertTrue(len([t for t in fixed if t[0] == 1]) == 0) # Unfixed input should not be in fixed
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue