Added back fixed_inputs

This commit is contained in:
Alan Saul 2016-03-21 17:46:18 +00:00
parent 5c53bc45e2
commit d8447a1c65
2 changed files with 38 additions and 2 deletions

View file

@ -235,8 +235,6 @@ def plot_density(self, plot_limits=None, fixed_inputs=None,
Give the Y_metadata in the predict_kw if you need it.
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input dimension i should be set to value v.

View file

@ -117,3 +117,41 @@ def align_subplot_array(axes,xlim=None, ylim=None):
ax.set_xticks([])
else:
removeUpperTicks(ax)
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
"""
Convenience function for returning back fixed_inputs where the other inputs
are fixed using fix_routine
:param model: model
:type model: Model
:param non_fixed_inputs: dimensions of non fixed inputs
:type non_fixed_inputs: list
:param fix_routine: fixing routine to use, 'mean', 'median', 'zero'
:type fix_routine: string
:param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix
:type as_list: boolean
"""
f_inputs = []
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean.values.copy()
elif isinstance(model.X, VariationalPosterior):
X = model.X.values.copy()
else:
if X_all:
X = model.X_all.copy()
else:
X = model.X.copy()
for i in range(X.shape[1]):
if i not in non_fixed_inputs:
if fix_routine == 'mean':
f_inputs.append( (i, np.mean(X[:,i])) )
if fix_routine == 'median':
f_inputs.append( (i, np.median(X[:,i])) )
else: # set to zero zero
f_inputs.append( (i, 0) )
if not as_list:
X[:,i] = f_inputs[-1][1]
if as_list:
return f_inputs
else:
return X