mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-13 14:03:20 +02:00
merge ties branch into psi2
This commit is contained in:
parent
36d53f815c
commit
fa6b5343cc
1 changed files with 80 additions and 0 deletions
80
GPy/core/parameterization/ties_and_remappings.py
Normal file
80
GPy/core/parameterization/ties_and_remappings.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Copyright (c) 2014, James Hensman, Max Zwiessele
|
||||||
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from parameterized import Parameterized
|
||||||
|
from param import Param
|
||||||
|
|
||||||
|
class Remapping(Parameterized):
|
||||||
|
def mapping(self):
|
||||||
|
"""
|
||||||
|
The return value of this function gives the values which the re-mapped
|
||||||
|
parameters should take. Implement in sub-classes.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def callback(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
def parameters_changed(self):
|
||||||
|
#ensure all out parameters have the correct value, as specified by our mapping
|
||||||
|
index = self._highest_parent_.constraints[self]
|
||||||
|
self._highest_parent_.param_array[index] = self.mapping()
|
||||||
|
[p.notify_observers(which=self) for p in self.tied_parameters]
|
||||||
|
|
||||||
|
class Fix(Remapping):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Tie(Remapping):
|
||||||
|
def __init__(self, value, name):
|
||||||
|
super(Tie, self).__init__(name)
|
||||||
|
self.tied_parameters = []
|
||||||
|
self.value = Param('val', value)
|
||||||
|
self.add_parameter(self.value)
|
||||||
|
|
||||||
|
def add_tied_parameter(self, p):
|
||||||
|
self.tied_parameters.append(p)
|
||||||
|
p.add_observer(self, self.callback)
|
||||||
|
self.parameters_changed()
|
||||||
|
|
||||||
|
def callback(self, param=None, which=None):
|
||||||
|
"""
|
||||||
|
This gets called whenever any of the tied parameters changes. we spend
|
||||||
|
considerable effort working out whhat has changed ant to what value.
|
||||||
|
Then we store that value in self.value, and broadcast it everywhere
|
||||||
|
with parameters_changed.
|
||||||
|
"""
|
||||||
|
if which is self:return
|
||||||
|
index = self._highest_parent_.constraints[self]
|
||||||
|
if len(index)==0:
|
||||||
|
return # nothing to tie together, this tie exists without any tied parameters
|
||||||
|
vals = self._highest_parent_.param_array[index]
|
||||||
|
uvals = np.unique(vals)
|
||||||
|
if len(uvals)==1:
|
||||||
|
#all of the tied things are at the same value
|
||||||
|
self.value[...] = uvals[0]
|
||||||
|
elif len(uvals)==2:
|
||||||
|
#only *one* of the tied things has changed. it must be different to self.value
|
||||||
|
newval = uvals[uvals != self.value*1]
|
||||||
|
self.value[...] = newval
|
||||||
|
else:
|
||||||
|
#more than one of the tied things changed. panic.
|
||||||
|
raise ValueError, "something is wrong with the tieing"
|
||||||
|
|
||||||
|
def mapping(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def collate_gradient(self):
|
||||||
|
index = self._highest_parent_.constraints[self]
|
||||||
|
self.value.gradient = np.sum(self._highest_parent_.gradient[index])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue