fixing now works, removing parameters needs fixing

This commit is contained in:
Max Zwiessele 2014-02-13 15:06:05 +00:00
parent d2e8807a88
commit f71af38505
7 changed files with 100 additions and 147 deletions

View file

@ -8,15 +8,10 @@ import cPickle
import itertools
from re import compile, _pattern_type
from param import ParamConcatenation, Param
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable, __fixed__
from parameter_core import Constrainable, Pickleable, Observable, adjust_name_for_printing, Gradcheckable
from transformations import __fixed__, FIXED, UNFIXED
from array_core import ParamList
#===============================================================================
# constants
FIXED = False
UNFIXED = True
#===============================================================================
class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
"""
Parameterized class
@ -71,37 +66,6 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
self._added_names_ = set()
del self._in_init_
#===========================================================================
# Parameter connection for model creation:
#===========================================================================
# def set_as_parameter(self, name, array, gradient, index=None, gradient_parent=None):
# """
# :param name: name of the param (in print and plots), can be callable without parameters
# :type name: str, callable
# :param array: array which the param consists of
# :type array: array-like
# :param gradient: gradient method of the param
# :type gradient: callable
# :param index: (optional) index of the param when printing
#
# (:param gradient_parent: connect these parameters to this class, but tell
# updates to highest_parent, this is needed when parameterized classes
# contain parameterized classes, but want to access the parameters
# of their children)
#
#
# Set array (e.g. self.X) as param with name and gradient.
# I.e: self.set_as_parameter('curvature', self.lengthscale, self.dK_dlengthscale)
#
# Note: the order in which parameters are added can be adjusted by
# giving an index, of where to put this param in printing
# """
# if index is None:
# self._parameters_.append(Param(name, array, gradient))
# else:
# self._parameters_.insert(index, Param(name, array, gradient))
# self._connect_parameters(gradient_parent=gradient_parent)
def _has_fixes(self):
return hasattr(self, "_fixes_") and self._fixes_ is not None
@ -118,53 +82,22 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
# if param.has_parent():
# raise AttributeError, "parameter {} already in another model, create new object (or copy) for adding".format(param._short())
if param in self._parameters_ and index is not None:
# make sure fixes and constraints are indexed right
if self._has_fixes():
param_slice = slice(self._offset_for(param), self._offset_for(param) + param.size)
dest_index = sum((p.size for p in self._parameters_[:index]))
dest_slice = slice(dest_index, dest_index + param.size)
fixes_param = self._fixes_[param_slice].copy()
self._fixes_[param_slice] = self._fixes_[dest_slice]
self._fixes_[dest_slice] = fixes_param
del self._parameters_[param._parent_index_]
self._parameters_.insert(index, param)
self.remove_parameter(param)
self.add_parameter(param, index)
elif param not in self._parameters_:
# make sure the size is set
if not hasattr(self, 'size'):
self.size = sum(p.size for p in self._parameters_)
if index is None:
self._parameters_.append(param)
# make sure fixes and constraints are indexed right
if param._has_fixes(): fixes_param = param._fixes_.copy()
else: fixes_param = numpy.ones(param.size, dtype=bool)
if self._has_fixes(): self._fixes_ = np.r_[self._fixes_, fixes_param]
elif param._has_fixes(): self._fixes_ = np.r_[np.ones(self.size, dtype=bool), fixes_param]
else:
start = sum(p.size for p in self._parameters_[:index])
self.constraints.shift(start, param.size)
self._parameters_.insert(index, param)
# make sure fixes and constraints are indexed right
if param._has_fixes(): fixes_param = param._fixes_.copy()
else: fixes_param = numpy.ones(param.size, dtype=bool)
ins = sum((p.size for p in self._parameters_[:index]))
if self._has_fixes(): self._fixes_ = np.r_[self._fixes_[:ins], fixes_param, self._fixes[ins:]]
elif not np.all(fixes_param):
self._fixes_ = np.ones(self.size + param.size, dtype=bool)
self._fixes_[ins:ins + param.size] = fixes_param
self.size += param.size
else:
raise RuntimeError, """Parameter exists already added and no copy made"""
self._connect_parameters()
for p in self._parameters_:
p._parent_changed(self)
if param._default_constraint_ is not None:
param.constrain(param._default_constraint_, False)
if self._has_fixes() and np.all(self._fixes_): # ==UNFIXED
self._fixes_ = None
def add_parameters(self, *parameters):
"""
@ -173,18 +106,20 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
"""
[self.add_parameter(p) for p in parameters]
def remove_parameter(self, *names_params_indices):
def remove_parameter(self, param):
"""
:param names_params_indices: mix of parameter_names, param objects, or indices
to remove from being a param of this parameterized object.
note: if it is a string object it will not (!) be regexp-matched
automatically.
:param param: param object to remove from being a parameter of this parameterized object.
"""
self._parameters_ = ParamList([p for p in self._parameters_
if not (p._parent_index_ in names_params_indices
or p.name in names_params_indices
or p in names_params_indices)])
if not param in self._parameters_:
raise RuntimeError, "Parameter {} does not belong to this object, remove parameters directly from their respective parents".format(param._short())
del self._parameters_[param._parent_index_]
self.size -= param.size
constr = param.constraints.copy()
param.constraints.clear()
param.constraints = constr
param._direct_parent_ = None
param._parent_index_ = None
param._connect_fixes()
self._connect_parameters()
def _connect_parameters(self):
@ -214,6 +149,8 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
elif not (pname in not_unique):
self.__dict__[pname] = p
self._added_names_.add(pname)
self._connect_fixes()
#===========================================================================
# Pickling operations
#===========================================================================
@ -337,7 +274,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
self._added_names_.add(pname)
self.__dict__[pname] = param
#===========================================================================
# Index Handling
# Indexable Handling
#===========================================================================
def _backtranslate_index(self, param, ind):
# translate an index in parameterized indexing into the index of param
@ -373,36 +310,10 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
#===========================================================================
# Fixing parameters:
#===========================================================================
def _set_fixed(self, param_or_index):
if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool)
try:
param_or_index = self._raveled_index_for(param_or_index)
except AttributeError:
pass
self._fixes_[param_or_index] = FIXED
if numpy.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _set_unfixed(self, param_or_index):
if not self._has_fixes(): self._fixes_ = numpy.ones(self.size, dtype=bool)
try:
param_or_index = self._raveled_index_for(param_or_index)
except AttributeError:
pass
self._fixes_[param_or_index] = UNFIXED
for constr, ind in self.constraints.iteritems():
if constr is __fixed__:
self._fixes_[ind] = FIXED
if numpy.all(self._fixes_): self._fixes_ = None # ==UNFIXED
def _fixes_for(self, param):
if self._has_fixes():
return self._fixes_[self._raveled_index_for(param)]
return numpy.ones(self.size, dtype=bool)[self._raveled_index_for(param)]
# def _fix(self, param, warning=True):
# f = self._add_constrain(param, __fixed__, warning)
# self._set_fixed(f)
# def _unfix(self, param):
# if self._has_fixes():
# f = self._remove_constrain(param, __fixed__)
# self._set_unfixed(f)
#===========================================================================
# Convenience for fixed, tied checking of param:
#===========================================================================