From fa6b5343cc649bac10fefba7f72df7087ecb1b9c Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Mon, 23 Jun 2014 14:37:11 +0100 Subject: [PATCH] merge ties branch into psi2 --- .../parameterization/ties_and_remappings.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 GPy/core/parameterization/ties_and_remappings.py diff --git a/GPy/core/parameterization/ties_and_remappings.py b/GPy/core/parameterization/ties_and_remappings.py new file mode 100644 index 00000000..75b46a95 --- /dev/null +++ b/GPy/core/parameterization/ties_and_remappings.py @@ -0,0 +1,80 @@ +# Copyright (c) 2014, James Hensman, Max Zwiessele +# Licensed under the BSD 3-clause license (see LICENSE.txt) + +import numpy as np +from parameterized import Parameterized +from param import Param + +class Remapping(Parameterized): + def mapping(self): + """ + The return value of this function gives the values which the re-mapped + parameters should take. Implement in sub-classes. + """ + raise NotImplementedError + + def callback(self): + raise NotImplementedError + + def __str__(self): + return self.name + + def parameters_changed(self): + #ensure all out parameters have the correct value, as specified by our mapping + index = self._highest_parent_.constraints[self] + self._highest_parent_.param_array[index] = self.mapping() + [p.notify_observers(which=self) for p in self.tied_parameters] + +class Fix(Remapping): + pass + + + + +class Tie(Remapping): + def __init__(self, value, name): + super(Tie, self).__init__(name) + self.tied_parameters = [] + self.value = Param('val', value) + self.add_parameter(self.value) + + def add_tied_parameter(self, p): + self.tied_parameters.append(p) + p.add_observer(self, self.callback) + self.parameters_changed() + + def callback(self, param=None, which=None): + """ + This gets called whenever any of the tied parameters changes. we spend + considerable effort working out whhat has changed ant to what value. + Then we store that value in self.value, and broadcast it everywhere + with parameters_changed. + """ + if which is self:return + index = self._highest_parent_.constraints[self] + if len(index)==0: + return # nothing to tie together, this tie exists without any tied parameters + vals = self._highest_parent_.param_array[index] + uvals = np.unique(vals) + if len(uvals)==1: + #all of the tied things are at the same value + self.value[...] = uvals[0] + elif len(uvals)==2: + #only *one* of the tied things has changed. it must be different to self.value + newval = uvals[uvals != self.value*1] + self.value[...] = newval + else: + #more than one of the tied things changed. panic. + raise ValueError, "something is wrong with the tieing" + + def mapping(self): + return self.value + + def collate_gradient(self): + index = self._highest_parent_.constraints[self] + self.value.gradient = np.sum(self._highest_parent_.gradient[index]) + + + + +