index operations finalized

This commit is contained in:
Max Zwiessele 2013-10-02 19:13:37 +01:00
parent e8437f7ec3
commit 269cc84253

View file

@ -7,44 +7,54 @@ import numpy
class ParameterIndexOperations(object):
'''
Index operations for storing parameter index restrictions
Index operations for storing parameter index _properties
This class enables indexing with slices retrieved from object.__getitem__ calls.
:param shape: shape of parameter, handled by this index restriction class
:param _shape: _shape of parameter, handled by this index restriction class
'''
def __init__(self, shape):
self.restrictions = {}
self.shape = shape
def __init__(self, param):
self._properties = {}
self._shape = param.shape
def get_restriction_indices(self, restriction):
def iteritems(self):
for prop, indices in self._properties.iteritems():
yield prop, numpy.unravel_index(indices, self._shape)
def keys(self):
return self._properties.keys()
def items(self):
return self._properties.items()
def indices(self, prop):
"""
get indices for restriction restriction.
get indices for prop prop.
these indices can be used as X[indices], which will be a flattened array of
all restricted elements
"""
return numpy.unravel_index(self.restrictions[restriction], self.shape)
return numpy.unravel_index(self._properties[prop], self._shape)
def add_restriction(self, restriction, indices):
ind = self._create_raveled_indices(indices)
if restriction in self.restrictions:
self.restrictions[restriction] = numpy.union1d(self.restrictions[restriction], ind)
def add(self, prop, indices):
ind = self.create_raveled_indices(indices)
if prop in self._properties:
self._properties[prop] = numpy.union1d(self._properties[prop], ind)
else:
self.restrictions[restriction] = ind
self._properties[prop] = ind
def remove_restriction(self, restriction, indices):
if restriction in self.restrictions:
ind = self._create_raveled_indices(indices)
diff = numpy.setdiff1d(self.restrictions[restriction], ind, True)
def remove(self, prop, indices):
if prop in self._properties:
ind = self.create_raveled_indices(indices)
diff = numpy.setdiff1d(self._properties[prop], ind, True)
if numpy.size(diff):
self.restrictions[restriction] = diff
self._properties[prop] = diff
else:
del self.restrictions[restriction]
del self._properties[prop]
def _create_raveled_indices(self, indices):
def create_raveled_indices(self, indices):
if isinstance(indices, (tuple, list)):
i = [slice(None)] + list(indices)
else:
i = [slice(None), indices]
return numpy.array(numpy.ravel_multi_index(numpy.indices(self.shape)[i], self.shape)).flatten()
return numpy.array(numpy.ravel_multi_index(numpy.indices(self._shape)[i], self._shape)).flatten()