Added back fixed_inputs

This commit is contained in:
Alan Saul 2016-03-21 17:46:18 +00:00
parent 5c53bc45e2
commit d8447a1c65
2 changed files with 38 additions and 2 deletions

View file

@ -117,3 +117,41 @@ def align_subplot_array(axes,xlim=None, ylim=None):
ax.set_xticks([])
else:
removeUpperTicks(ax)
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
"""
Convenience function for returning back fixed_inputs where the other inputs
are fixed using fix_routine
:param model: model
:type model: Model
:param non_fixed_inputs: dimensions of non fixed inputs
:type non_fixed_inputs: list
:param fix_routine: fixing routine to use, 'mean', 'median', 'zero'
:type fix_routine: string
:param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix
:type as_list: boolean
"""
f_inputs = []
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean.values.copy()
elif isinstance(model.X, VariationalPosterior):
X = model.X.values.copy()
else:
if X_all:
X = model.X_all.copy()
else:
X = model.X.copy()
for i in range(X.shape[1]):
if i not in non_fixed_inputs:
if fix_routine == 'mean':
f_inputs.append( (i, np.mean(X[:,i])) )
if fix_routine == 'median':
f_inputs.append( (i, np.median(X[:,i])) )
else: # set to zero zero
f_inputs.append( (i, 0) )
if not as_list:
X[:,i] = f_inputs[-1][1]
if as_list:
return f_inputs
else:
return X