hierarchy edits. adding removing parameters withing hierarchy

This commit is contained in:
Max Zwiessele 2014-02-28 16:18:47 +00:00
parent c87bda9e49
commit 47e4026141
11 changed files with 106 additions and 64 deletions

View file

@ -7,7 +7,7 @@ import cPickle
import itertools
from re import compile, _pattern_type
from param import ParamConcatenation
from parameter_core import Constrainable, Pickleable, Parentable, Observable, Parameterizable, adjust_name_for_printing, Gradcheckable
from parameter_core import Pickleable, Parameterizable, adjust_name_for_printing, Gradcheckable
from transformations import __fixed__
from array_core import ParamList
@ -105,6 +105,14 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
self.remove_parameter(param)
self.add_parameter(param, index)
elif param not in self._parameters_:
if param.has_parent():
parent = param._direct_parent_
while parent is not None:
if parent is self:
from GPy.core.parameterization.parameter_core import HierarchyError
raise HierarchyError, "You cannot add a parameter twice into the hirarchy"
parent = parent._direct_parent_
param._direct_parent_.remove_parameter(param)
# make sure the size is set
if index is None:
self.constraints.update(param.constraints, self.size)
@ -117,13 +125,16 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
self.constraints.update(param.constraints, start)
self.priors.update(param.priors, start)
self._parameters_.insert(index, param)
param.add_observer(self, self._pass_through_notify_observers, -np.inf)
self.size += param.size
self._connect_parameters()
self._notify_parent_change()
self._connect_fixes()
else:
raise RuntimeError, """Parameter exists already added and no copy made"""
self._connect_parameters()
self._notify_parent_change()
self._connect_fixes()
def add_parameters(self, *parameters):
@ -146,12 +157,19 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
del self._parameters_[param._parent_index_]
param._disconnect_parent()
param.remove_observer(self, self._notify_parameters_changed)
param.remove_observer(self, self._pass_through_notify_observers)
self.constraints.shift_left(start, param.size)
self._connect_fixes()
self._connect_parameters()
self._notify_parent_change()
parent = self._direct_parent_
while parent is not None:
parent._connect_fixes()
parent._connect_parameters()
parent._notify_parent_change()
parent = parent._direct_parent_
def _connect_parameters(self):
# connect parameterlist to this parameterized object
@ -351,7 +369,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
# Printing:
#===========================================================================
def _short(self):
return self.hirarchy_name()
return self.hierarchy_name()
@property
def flattened_parameters(self):
return [xi for x in self._parameters_ for xi in x.flattened_parameters]
@ -359,11 +377,6 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
def _parameter_sizes_(self):
return [x.size for x in self._parameters_]
@property
def size_transformed(self):
if self._has_fixes():
return sum(self._fixes_)
return self.size
@property
def parameter_shapes(self):
return [xi for x in self._parameters_ for xi in x.parameter_shapes]
@property