expand_param and extract_param replaced with set_params_transformed and get_params_transformed

This commit is contained in:
Neil Lawrence 2013-01-18 13:37:17 +00:00
parent 7350143dd0
commit 1674bc529b
8 changed files with 54 additions and 54 deletions

View file

@ -66,7 +66,7 @@ class parameterised(object):
if hasattr(self,'prior'):
pass
self.expand_param(self.extract_param())# sets tied parameters to single value
self._set_params_transformed(self._get_params_transformed())# sets tied parameters to single value
def untie_everything(self):
"""Unties all parameters by setting tied_indices to an empty list."""
@ -216,9 +216,9 @@ class parameterised(object):
self.constrained_fixed_values.append(self._get_params()[self.constrained_fixed_indices[-1]])
#self.constrained_fixed_values.append(value)
self.expand_param(self.extract_param())
self._set_params_transformed(self._get_params_transformed())
def extract_param(self):
def _get_params_transformed(self):
"""use self._get_params to get the 'true' parameters of the model, which are then tied, constrained and fixed"""
x = self._get_params()
x[self.constrained_positive_indices] = np.log(x[self.constrained_positive_indices])
@ -232,7 +232,7 @@ class parameterised(object):
return x
def expand_param(self,x):
def _set_params_transformed(self,x):
""" takes the vector x, which is then modified (by untying, reparameterising or inserting fixed values), and then call self._set_params"""
#work out how many places are fixed, and where they are. tricky logic!
@ -259,10 +259,10 @@ class parameterised(object):
[np.put(xx,i,low+sigmoid(xx[i])*(high-low)) for i,low,high in zip(self.constrained_bounded_indices, self.constrained_bounded_lowers, self.constrained_bounded_uppers)]
self._set_params(xx)
def extract_param_names(self):
def _get_param_names_transformed(self):
"""
Returns the parameter names as propagated after constraining,
tying or fixing, i.e. a list of the same length as extract_param()
tying or fixing, i.e. a list of the same length as _get_params_transformed()
"""
n = self._get_param_names()