diff --git a/GPy/core/index_operations.py b/GPy/core/index_operations.py new file mode 100644 index 00000000..d6dbe140 --- /dev/null +++ b/GPy/core/index_operations.py @@ -0,0 +1,50 @@ +''' +Created on Oct 2, 2013 + +@author: maxzwiessele +''' +import numpy + +class ParameterIndexOperations(object): + ''' + Index operations for storing parameter index restrictions + This class enables indexing with slices retrieved from object.__getitem__ calls. + + :param shape: shape of parameter, handled by this index restriction class + ''' + + + def __init__(self, shape): + self.restrictions = {} + self.shape = shape + + def get_restriction_indices(self, restriction): + """ + get indices for restriction restriction. + these indices can be used as X[indices], which will be a flattened array of + all restricted elements + """ + return numpy.unravel_index(self.restrictions[restriction], self.shape) + + def add_restriction(self, restriction, indices): + ind = self._create_raveled_indices(indices) + if restriction in self.restrictions: + self.restrictions[restriction] = numpy.union1d(self.restrictions[restriction], ind) + else: + self.restrictions[restriction] = ind + + def remove_restriction(self, restriction, indices): + if restriction in self.restrictions: + ind = self._create_raveled_indices(indices) + diff = numpy.setdiff1d(self.restrictions[restriction], ind, True) + if numpy.size(diff): + self.restrictions[restriction] = diff + else: + del self.restrictions[restriction] + + def _create_raveled_indices(self, indices): + if isinstance(indices, (tuple, list)): + i = [slice(None)] + list(indices) + else: + i = [slice(None), indices] + return numpy.array(numpy.ravel_multi_index(numpy.indices(self.shape)[i], self.shape)).flatten() diff --git a/GPy/core/parameter.py b/GPy/core/parameter.py index 5c9d2849..e95ce19b 100644 --- a/GPy/core/parameter.py +++ b/GPy/core/parameter.py @@ -7,6 +7,7 @@ import re import itertools import numpy from GPy.core.transformations import Logexp +from GPy.core.index_operations import ParameterIndexOperations _index_re = re.compile('(?:_(\d+))+') # pattern match for indices def translate_param_names_to_parameters(param_names): @@ -73,17 +74,6 @@ class Parameters(object): return '\n'.join([x.__str__(format_spec=format_spec) for x in self._params]) pass -class ParameterIndexing(object): - def __init__(self, corresponding_param): - self.properties = {} - self.param = corresponding_param - def add(self, prop, s): - if prop in self.properties.keys(): - self.properties[prop] = self.combine_indices(self.properties[prop], s) - else: - self.properties[prop] = [numpy.r_[st] for st in s] - def combine_indices(self, s1, s2): - return [numpy.union1d(numpy.r_[ar1], numpy.r_[ar2]) for ar1, ar2 in itertools.izip_longest(s1, s2)] class Parameter(object): tied_to = [] # list of parameters this parameter is tied to @@ -91,7 +81,7 @@ class Parameter(object): def __init__(self, name, value, constraint=None, *args, **kwargs): self.name = name - self.constraints = ParameterIndexing(self) + self.constraints = ParameterIndexOperations(self) self._value = value self._current_slice = slice(None) @@ -129,11 +119,8 @@ class Parameter(object): def _get_params_transformed(self): params = self.value.copy() - import ipdb;ipdb.set_trace() - return def constrain_positive(self): - import ipdb;ipdb.set_trace() self.constraints.add(Logexp(), self._current_slice) self._current_slice = slice(None)