diff --git a/GPy/core/index_operations.py b/GPy/core/index_operations.py index d6dbe140..1527c989 100644 --- a/GPy/core/index_operations.py +++ b/GPy/core/index_operations.py @@ -7,44 +7,54 @@ import numpy class ParameterIndexOperations(object): ''' - Index operations for storing parameter index restrictions + Index operations for storing parameter index _properties This class enables indexing with slices retrieved from object.__getitem__ calls. - :param shape: shape of parameter, handled by this index restriction class + :param _shape: _shape of parameter, handled by this index restriction class ''' - def __init__(self, shape): - self.restrictions = {} - self.shape = shape + def __init__(self, param): + self._properties = {} + self._shape = param.shape - def get_restriction_indices(self, restriction): + def iteritems(self): + for prop, indices in self._properties.iteritems(): + yield prop, numpy.unravel_index(indices, self._shape) + + def keys(self): + return self._properties.keys() + + def items(self): + return self._properties.items() + + def indices(self, prop): """ - get indices for restriction restriction. + get indices for prop prop. these indices can be used as X[indices], which will be a flattened array of all restricted elements """ - return numpy.unravel_index(self.restrictions[restriction], self.shape) + return numpy.unravel_index(self._properties[prop], self._shape) - def add_restriction(self, restriction, indices): - ind = self._create_raveled_indices(indices) - if restriction in self.restrictions: - self.restrictions[restriction] = numpy.union1d(self.restrictions[restriction], ind) + def add(self, prop, indices): + ind = self.create_raveled_indices(indices) + if prop in self._properties: + self._properties[prop] = numpy.union1d(self._properties[prop], ind) else: - self.restrictions[restriction] = ind + self._properties[prop] = ind - def remove_restriction(self, restriction, indices): - if restriction in self.restrictions: - ind = self._create_raveled_indices(indices) - diff = numpy.setdiff1d(self.restrictions[restriction], ind, True) + def remove(self, prop, indices): + if prop in self._properties: + ind = self.create_raveled_indices(indices) + diff = numpy.setdiff1d(self._properties[prop], ind, True) if numpy.size(diff): - self.restrictions[restriction] = diff + self._properties[prop] = diff else: - del self.restrictions[restriction] + del self._properties[prop] - def _create_raveled_indices(self, indices): + def create_raveled_indices(self, indices): if isinstance(indices, (tuple, list)): i = [slice(None)] + list(indices) else: i = [slice(None), indices] - return numpy.array(numpy.ravel_multi_index(numpy.indices(self.shape)[i], self.shape)).flatten() + return numpy.array(numpy.ravel_multi_index(numpy.indices(self._shape)[i], self._shape)).flatten()