mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
index operations finalized
This commit is contained in:
parent
e8437f7ec3
commit
269cc84253
1 changed files with 31 additions and 21 deletions
|
|
@ -7,44 +7,54 @@ import numpy
|
|||
|
||||
class ParameterIndexOperations(object):
|
||||
'''
|
||||
Index operations for storing parameter index restrictions
|
||||
Index operations for storing parameter index _properties
|
||||
This class enables indexing with slices retrieved from object.__getitem__ calls.
|
||||
|
||||
:param shape: shape of parameter, handled by this index restriction class
|
||||
:param _shape: _shape of parameter, handled by this index restriction class
|
||||
'''
|
||||
|
||||
|
||||
def __init__(self, shape):
|
||||
self.restrictions = {}
|
||||
self.shape = shape
|
||||
def __init__(self, param):
|
||||
self._properties = {}
|
||||
self._shape = param.shape
|
||||
|
||||
def get_restriction_indices(self, restriction):
|
||||
def iteritems(self):
|
||||
for prop, indices in self._properties.iteritems():
|
||||
yield prop, numpy.unravel_index(indices, self._shape)
|
||||
|
||||
def keys(self):
|
||||
return self._properties.keys()
|
||||
|
||||
def items(self):
|
||||
return self._properties.items()
|
||||
|
||||
def indices(self, prop):
|
||||
"""
|
||||
get indices for restriction restriction.
|
||||
get indices for prop prop.
|
||||
these indices can be used as X[indices], which will be a flattened array of
|
||||
all restricted elements
|
||||
"""
|
||||
return numpy.unravel_index(self.restrictions[restriction], self.shape)
|
||||
return numpy.unravel_index(self._properties[prop], self._shape)
|
||||
|
||||
def add_restriction(self, restriction, indices):
|
||||
ind = self._create_raveled_indices(indices)
|
||||
if restriction in self.restrictions:
|
||||
self.restrictions[restriction] = numpy.union1d(self.restrictions[restriction], ind)
|
||||
def add(self, prop, indices):
|
||||
ind = self.create_raveled_indices(indices)
|
||||
if prop in self._properties:
|
||||
self._properties[prop] = numpy.union1d(self._properties[prop], ind)
|
||||
else:
|
||||
self.restrictions[restriction] = ind
|
||||
self._properties[prop] = ind
|
||||
|
||||
def remove_restriction(self, restriction, indices):
|
||||
if restriction in self.restrictions:
|
||||
ind = self._create_raveled_indices(indices)
|
||||
diff = numpy.setdiff1d(self.restrictions[restriction], ind, True)
|
||||
def remove(self, prop, indices):
|
||||
if prop in self._properties:
|
||||
ind = self.create_raveled_indices(indices)
|
||||
diff = numpy.setdiff1d(self._properties[prop], ind, True)
|
||||
if numpy.size(diff):
|
||||
self.restrictions[restriction] = diff
|
||||
self._properties[prop] = diff
|
||||
else:
|
||||
del self.restrictions[restriction]
|
||||
del self._properties[prop]
|
||||
|
||||
def _create_raveled_indices(self, indices):
|
||||
def create_raveled_indices(self, indices):
|
||||
if isinstance(indices, (tuple, list)):
|
||||
i = [slice(None)] + list(indices)
|
||||
else:
|
||||
i = [slice(None), indices]
|
||||
return numpy.array(numpy.ravel_multi_index(numpy.indices(self.shape)[i], self.shape)).flatten()
|
||||
return numpy.array(numpy.ravel_multi_index(numpy.indices(self._shape)[i], self._shape)).flatten()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue