diff --git a/GPy/core/index_operations.py b/GPy/core/index_operations.py index 63856a6d..15bea339 100644 --- a/GPy/core/index_operations.py +++ b/GPy/core/index_operations.py @@ -4,9 +4,37 @@ Created on Oct 2, 2013 @author: maxzwiessele ''' import numpy -import itertools +from numpy.lib.function_base import vectorize +from parameter import Param -class ConstraintIndexOperations(object): +class ParamDict(dict): + def __getitem__(self, key): + try: + return super(ParamDict, self).__getitem__(key) + except KeyError: + for a in self.iterkeys(): + if numpy.all(a==key) and a._parent_index_==key._parent_index_: + return super(ParamDict, self).__getitem__(a) + raise + + def __contains__(self, key): + if super(ParamDict, self).__contains__(key): + return True + for a in self.iterkeys(): + if numpy.all(a==key) and a._parent_index_==key._parent_index_: + return True + return False + + def __setitem__(self, key, value): + if isinstance(key, Param): + for a in self.iterkeys(): + if numpy.all(a==key) and a._parent_index_==key._parent_index_: + return super(ParamDict, self).__setitem__(a, value) + raise KeyError, key + super(ParamDict, self).__setitem__(key, value) + + +class ParameterIndexOperations(object): ''' Index operations for storing parameter index _properties This class enables index with slices retrieved from object.__getitem__ calls. @@ -14,13 +42,19 @@ class ConstraintIndexOperations(object): indexing a shape shaped array to the flattened index array. Remove will remove the selected slice indices from the flattened array. You can give an offset to set an offset for the given indices in the - index array, for multiparameter handling. - - + index array, for multi-parameter handling. ''' def __init__(self): - self._properties = {} + self._properties = ParamDict() + #self._reverse = collections.defaultdict(list) + def __getstate__(self): + return self._properties, self._reverse + + def __setstate__(self, state): + self._properties = state[0] + self._reverse = state[1] + def iteritems(self): return self._properties.iteritems() @@ -41,39 +75,50 @@ class ConstraintIndexOperations(object): def indices(self): return self._properties.values() + + def properties_for(self, index): +# already_seen = dict() +# for ni in index: +# if ni not in already_seen: +# already_seen[ni] = [prop for prop in self.iter_properties() if ni in self._properties[prop]] +# yield already_seen[ni] + return vectorize(lambda i: [prop for prop in self.iter_properties() if i in self._properties[prop]], otypes=[list])(index) def add(self, prop, indices, shape, offset=False): ind = create_raveled_indices(indices, shape, offset) - if prop in self._properties: + #[self._reverse[i].__add__(prop) for i in ind] + try: self._properties[prop] = combine_indices(self._properties[prop], ind) - return - for a in self.properties(): - if numpy.all(a==prop) and a._parent_index == prop._parent_index: - self._properties[a] = combine_indices(self._properties[a], ind) - return - self._properties[prop] = ind + except KeyError: +# for a in self.properties(): +# if numpy.all(a==prop) and a._parent_index_ == prop._parent_index_: +# self._properties[a] = combine_indices(self._properties[a], ind) +# return + self._properties[prop] = ind def remove(self, prop, indices, shape, offset=False): if prop in self._properties: - ind = create_raveled_indices(indices, shape, offset) - diff = remove_indices(self[prop], ind) - removed = numpy.intersect1d(self[prop], ind, True) - if not index_empty(diff): - self._properties[prop] = diff - else: - del self._properties[prop] - return removed.astype(int) - else: - for a in self.properties(): - if numpy.all(a==prop) and a._parent_index == prop._parent_index: - ind = create_raveled_indices(indices, shape, offset) - diff = remove_indices(self[a], ind) - removed = numpy.intersect1d(self[a], ind, True) - if not index_empty(diff): - self._properties[a] = diff - else: - del self._properties[a] - return removed.astype(int) + ind = create_raveled_indices(indices, shape, offset) + diff = remove_indices(self[prop], ind) + removed = numpy.intersect1d(self[prop], ind, True) + if not index_empty(diff): + self._properties[prop] = diff + else: + del self._properties[prop] + #[self._reverse[i].remove(prop) for i in removed if prop in self._reverse[i]] + return removed.astype(int) +# else: +# for a in self.properties(): +# if numpy.all(a==prop) and a._parent_index_ == prop._parent_index_: +# ind = create_raveled_indices(indices, shape, offset) +# diff = remove_indices(self[a], ind) +# removed = numpy.intersect1d(self[a], ind, True) +# if not index_empty(diff): +# self._properties[a] = diff +# else: +# del self._properties[a] +# [self._reverse[i].remove(a) for i in removed if a in self._reverse[i]] +# return removed.astype(int) return numpy.array([]).astype(int) def __getitem__(self, prop): return self._properties[prop] @@ -81,17 +126,23 @@ class ConstraintIndexOperations(object): class TieIndexOperations(object): def __init__(self, params): self.params = params - self.tied_from = ConstraintIndexOperations() - self.tied_to = ConstraintIndexOperations() + self.tied_from = ParameterIndexOperations() + self.tied_to = ParameterIndexOperations() def add(self, tied_from, tied_to): - self.tied_from.add(tied_to, tied_from._current_slice, tied_from._realshape, self.params._offset(tied_from)) - self.tied_to.add(tied_to, tied_to._current_slice, tied_to._realshape, self.params._offset(tied_to)) + self.tied_from.add(tied_to, tied_from._current_slice_, tied_from._realshape_, self.params._offset(tied_from)) + self.tied_to.add(tied_to, tied_to._current_slice_, tied_to._realshape_, self.params._offset(tied_to)) def remove(self, tied_from, tied_to): - self.tied_from.remove(tied_to, tied_from._current_slice, tied_from._realshape, self.params._offset(tied_from)) - self.tied_to.remove(tied_to, tied_to._current_slice, tied_to._realshape, self.params._offset(tied_to)) + self.tied_from.remove(tied_to, tied_from._current_slice_, tied_from._realshape_, self.params._offset(tied_from)) + self.tied_to.remove(tied_to, tied_to._current_slice_, tied_to._realshape_, self.params._offset(tied_to)) + def from_to_for(self, index): + return self.tied_from.properties_for(index), self.tied_to.properties_for(index) def iter_from_to_indices(self): for k, f in self.tied_from.iteritems(): yield f, self.tied_to[k] + def iter_to_indices(self): + return self.tied_to.iterindices() + def iter_from_indices(self): + return self.tied_from.iterindices() def iter_from_items(self): for f, i in self.tied_from.iteritems(): yield f, i @@ -102,7 +153,7 @@ class TieIndexOperations(object): def from_to_indices(self, param): return self.tied_from[param], self.tied_to[param] -def create_raveled_indices(index, shape, offset=False): +def create_raveled_indices(index, shape, offset=0): if isinstance(index, (tuple, list)): i = [slice(None)] + list(index) else: i = [slice(None), index] ind = numpy.array(numpy.ravel_multi_index(numpy.indices(shape)[i], shape)).flat + numpy.int_(offset)