merge ties branch into psi2

This commit is contained in:
Zhenwen Dai 2014-06-23 14:37:11 +01:00
parent 36d53f815c
commit fa6b5343cc

View file

@ -0,0 +1,80 @@
# Copyright (c) 2014, James Hensman, Max Zwiessele
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from parameterized import Parameterized
from param import Param
class Remapping(Parameterized):
def mapping(self):
"""
The return value of this function gives the values which the re-mapped
parameters should take. Implement in sub-classes.
"""
raise NotImplementedError
def callback(self):
raise NotImplementedError
def __str__(self):
return self.name
def parameters_changed(self):
#ensure all out parameters have the correct value, as specified by our mapping
index = self._highest_parent_.constraints[self]
self._highest_parent_.param_array[index] = self.mapping()
[p.notify_observers(which=self) for p in self.tied_parameters]
class Fix(Remapping):
pass
class Tie(Remapping):
def __init__(self, value, name):
super(Tie, self).__init__(name)
self.tied_parameters = []
self.value = Param('val', value)
self.add_parameter(self.value)
def add_tied_parameter(self, p):
self.tied_parameters.append(p)
p.add_observer(self, self.callback)
self.parameters_changed()
def callback(self, param=None, which=None):
"""
This gets called whenever any of the tied parameters changes. we spend
considerable effort working out whhat has changed ant to what value.
Then we store that value in self.value, and broadcast it everywhere
with parameters_changed.
"""
if which is self:return
index = self._highest_parent_.constraints[self]
if len(index)==0:
return # nothing to tie together, this tie exists without any tied parameters
vals = self._highest_parent_.param_array[index]
uvals = np.unique(vals)
if len(uvals)==1:
#all of the tied things are at the same value
self.value[...] = uvals[0]
elif len(uvals)==2:
#only *one* of the tied things has changed. it must be different to self.value
newval = uvals[uvals != self.value*1]
self.value[...] = newval
else:
#more than one of the tied things changed. panic.
raise ValueError, "something is wrong with the tieing"
def mapping(self):
return self.value
def collate_gradient(self):
index = self._highest_parent_.constraints[self]
self.value.gradient = np.sum(self._highest_parent_.gradient[index])