mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
some initial work on ties
This commit is contained in:
parent
37e46b5da3
commit
c554c91339
5 changed files with 80 additions and 24 deletions
|
|
@ -1,22 +1,48 @@
|
||||||
'''
|
# Copyright (c) 2014, Max Zwiessele
|
||||||
Created on Oct 2, 2013
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
@author: maxzwiessele
|
|
||||||
'''
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.lib.function_base import vectorize
|
from numpy.lib.function_base import vectorize
|
||||||
from lists_and_dicts import IntArrayDict
|
from lists_and_dicts import IntArrayDict
|
||||||
|
|
||||||
class ParameterIndexOperations(object):
|
class ParameterIndexOperations(object):
|
||||||
'''
|
"""
|
||||||
Index operations for storing param index _properties
|
This object wraps a dictionary, whos keys are _operations_ that we'd like
|
||||||
This class enables index with slices retrieved from object.__getitem__ calls.
|
to apply to a parameter array, and whose values are np integer arrays which
|
||||||
Adding an index will add the selected indexes by the slice of an indexarray
|
index the parameter array appropriately.
|
||||||
indexing a shape shaped array to the flattened index array. Remove will
|
|
||||||
remove the selected slice indices from the flattened array.
|
A model instance will contain one instance of this class for each thing
|
||||||
You can give an offset to set an offset for the given indices in the
|
that needs indexing (i.e. constraints, ties and priors). Parameters within
|
||||||
index array, for multi-param handling.
|
the model constain instances of the ParameterIndexOperationsView class,
|
||||||
'''
|
which can map from a 'local' index (starting 0) to this global index.
|
||||||
|
|
||||||
|
Here's an illustration:
|
||||||
|
|
||||||
|
#=======================================================================
|
||||||
|
model : 0 1 2 3 4 5 6 7 8 9
|
||||||
|
key1: 4 5
|
||||||
|
key2: 7 8
|
||||||
|
|
||||||
|
param1: 0 1 2 3 4 5
|
||||||
|
key1: 2 3
|
||||||
|
key2: 5
|
||||||
|
|
||||||
|
param2: 0 1 2 3 4
|
||||||
|
key1: 0
|
||||||
|
key2: 2 3
|
||||||
|
#=======================================================================
|
||||||
|
|
||||||
|
The views of this global index have a subset of the keys in this global
|
||||||
|
(model) index.
|
||||||
|
|
||||||
|
Adding a new key (e.g. a constraint) to a view will cause the view to pass
|
||||||
|
the new key to the global index, along with the local index and an offset.
|
||||||
|
This global index then stores the key and the appropriate global index
|
||||||
|
(which can be seen by the view).
|
||||||
|
|
||||||
|
See also:
|
||||||
|
ParameterIndexOperationsView
|
||||||
|
|
||||||
|
"""
|
||||||
_offset = 0
|
_offset = 0
|
||||||
def __init__(self, constraints=None):
|
def __init__(self, constraints=None):
|
||||||
self._properties = IntArrayDict()
|
self._properties = IntArrayDict()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2014, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
#t Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
"""
|
"""
|
||||||
Core module for parameterization.
|
Core module for parameterization.
|
||||||
|
|
@ -471,6 +471,8 @@ class Constrainable(Nameable, Indexable, Observable):
|
||||||
|
|
||||||
Constrain the parameter to the given
|
Constrain the parameter to the given
|
||||||
:py:class:`GPy.core.transformations.Transformation`.
|
:py:class:`GPy.core.transformations.Transformation`.
|
||||||
|
|
||||||
|
:returns added: the indices that were constrained
|
||||||
"""
|
"""
|
||||||
self.param_array[...] = transform.initialize(self.param_array)
|
self.param_array[...] = transform.initialize(self.param_array)
|
||||||
reconstrained = self.unconstrain()
|
reconstrained = self.unconstrain()
|
||||||
|
|
@ -478,6 +480,37 @@ class Constrainable(Nameable, Indexable, Observable):
|
||||||
self.notify_observers(self, None if trigger_parent else -np.inf)
|
self.notify_observers(self, None if trigger_parent else -np.inf)
|
||||||
return added
|
return added
|
||||||
|
|
||||||
|
def tie(self, name):
|
||||||
|
#remove any constraints
|
||||||
|
old_const = self.constraints.properties()[:]
|
||||||
|
self.unconstrain()
|
||||||
|
|
||||||
|
#set these parameters to be 'fixed' as in, not optimized
|
||||||
|
self._highest_parent_._set_fixed(self, self._raveled_index())
|
||||||
|
|
||||||
|
#see if a tie exists with that name
|
||||||
|
if name in self._highest_parent_.ties:
|
||||||
|
t = self._highest_parent_.ties[name]
|
||||||
|
else:
|
||||||
|
#create a tie object
|
||||||
|
value = np.atleast_1d(self.param_array)[0]*1
|
||||||
|
import ties_and_remappings
|
||||||
|
t = ties_and_remappings.Tie(value=value, name=name)
|
||||||
|
|
||||||
|
#add the new tie object to the global index
|
||||||
|
self._highest_parent_.ties[name] = t
|
||||||
|
self._highest_parent_.add_parameter(t)
|
||||||
|
|
||||||
|
#constrain the tie as we were constrained
|
||||||
|
if len(old_const)==1:
|
||||||
|
t.constrain(old_const[0])
|
||||||
|
|
||||||
|
|
||||||
|
self.constraints.add(t, self._raveled_index())
|
||||||
|
t.add_tied_parameter(self)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def unconstrain(self, *transforms):
|
def unconstrain(self, *transforms):
|
||||||
"""
|
"""
|
||||||
:param transforms: The transformations to unconstrain from.
|
:param transforms: The transformations to unconstrain from.
|
||||||
|
|
@ -712,6 +745,7 @@ class Parameterizable(OptimizationHandlable):
|
||||||
self._parameters_ = ArrayList()
|
self._parameters_ = ArrayList()
|
||||||
self.size = 0
|
self.size = 0
|
||||||
self._added_names_ = set()
|
self._added_names_ = set()
|
||||||
|
self.ties = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def param_array(self):
|
def param_array(self):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2014, Max Zwiessele
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
|
|
||||||
import numpy; np = numpy
|
import numpy; np = numpy
|
||||||
import cPickle
|
import cPickle
|
||||||
import itertools
|
import itertools
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
'''
|
# Copyright (c) 2014, Max Zwiessele
|
||||||
Created on 12 Feb 2014
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
@author: maxz
|
|
||||||
'''
|
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from GPy.core.parameterization.index_operations import ParameterIndexOperations,\
|
from GPy.core.parameterization.index_operations import ParameterIndexOperations,\
|
||||||
|
|
@ -99,4 +96,4 @@ class Test(unittest.TestCase):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#import sys;sys.argv = ['', 'Test.test_index_view']
|
#import sys;sys.argv = ['', 'Test.test_index_view']
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue