GPy/GPy/core/parameterization/index_operations.py

235 lines
6.8 KiB
Python
Raw Normal View History

'''
Created on Oct 2, 2013
@author: maxzwiessele
'''
import numpy
2013-10-16 21:07:54 +01:00
from numpy.lib.function_base import vectorize
from param import Param
from collections import defaultdict
class ParamDict(defaultdict):
    """
    Dictionary with default values whose keys may be array-like parameters.

    A key that the ordinary hash lookup does not find is additionally
    compared against every stored key: a stored key ``a`` is treated as the
    same entry when ``numpy.all(a == key)`` holds and both keys carry the
    same ``_parent_index_``.

    Subclasses provide the default value by defining a ``default_factory``
    method. Without one, unknown keys raise ``KeyError``.
    """

    def __init__(self):
        """
        Default will be self.default_factory, if not set otherwise
        """
        defaultdict.__init__(self, self.default_factory)

    def _find_equivalent(self, key):
        """
        Scan the stored keys for one that is equivalent to `key`.

        :returns: a ``(found, stored_key)`` pair; a pair is returned so that
                  any object (including None) can be used as a key.
        """
        for a in self:
            if numpy.all(a == key) and a._parent_index_ == key._parent_index_:
                return True, a
        return False, None

    def __getitem__(self, key):
        """
        Return the value stored for `key` or for a key equivalent to it.

        Only when neither exists is the lookup handed to defaultdict, which
        creates the default value or raises KeyError if there is no factory.
        """
        if dict.__contains__(self, key):
            return dict.__getitem__(self, key)
        # The equivalent-key search has to happen *before* the defaultdict
        # lookup: with a default factory set, defaultdict never raises
        # KeyError, it stores a fresh default through __setitem__ instead,
        # which would replace the value of the equivalent stored key.
        found, stored = self._find_equivalent(key)
        if found:
            return dict.__getitem__(self, stored)
        return defaultdict.__getitem__(self, key)

    def __contains__(self, key):
        """Return True if `key` or an equivalent key is stored."""
        if defaultdict.__contains__(self, key):
            return True
        return self._find_equivalent(key)[0]

    def __setitem__(self, key, value):
        """
        Store `value` under `key`.

        For Param keys an already stored equivalent key is reused, so that
        one parameter never ends up with two entries.
        """
        if isinstance(key, Param):
            found, stored = self._find_equivalent(key)
            if found:
                key = stored
        defaultdict.__setitem__(self, key, value)
2013-12-07 18:45:24 +00:00
class SetDict(ParamDict):
    """ParamDict whose missing entries default to a new, empty set."""

    def default_factory(self):
        """Build the default value: an empty set, created anew on each call."""
        empty = set()
        return empty
class IntArrayDict(ParamDict):
    """ParamDict whose missing entries default to an empty integer array."""

    def default_factory(self):
        """Build the default value: a new, empty array created by numpy.int_."""
        no_indices = numpy.int_([])
        return no_indices
2013-10-16 21:07:54 +01:00
class ParameterIndexOperations(object):
    '''
    Index operations for storing the indices that belong to a property.

    Every property (the key) maps to an integer array holding the indices
    for which the property is set. `add` merges new indices into the array
    of a property, `remove` takes indices out again and drops the property
    once no index is left.
    '''

    def __init__(self, constraints=None):
        """
        :param constraints: optional mapping ``property -> indices`` (anything
                            with an ``items()`` method) used to fill the
                            new instance.
        """
        self._properties = IntArrayDict()
        if constraints is not None:
            for t, i in constraints.items():
                self.add(t, i)

    def __getstate__(self):
        """Return the picklable state: the property -> indices mapping."""
        return self._properties

    def __setstate__(self, state):
        """
        Restore the mapping produced by `__getstate__`.

        `__getstate__` hands out the mapping itself, so it must not be
        indexed here. A tuple/list state with the mapping in first place is
        unpacked as well.
        """
        if isinstance(state, (tuple, list)):
            state = state[0]
        self._properties = state

    def iteritems(self):
        """Iterate over ``(property, indices)`` pairs."""
        return iter(self._properties.items())

    def items(self):
        """Return a list of ``(property, indices)`` pairs."""
        return list(self._properties.items())

    def properties(self):
        """Return a list of all stored properties."""
        return list(self._properties.keys())

    def iterproperties(self):
        """Iterate over all stored properties."""
        return iter(self._properties.keys())

    def shift(self, start, size):
        """Add `size`, in place, to every stored index that is >= `start`."""
        for ind in self.iterindices():
            # An all-False (or empty) mask selects nothing, so no guard
            # around the in-place update is needed.
            ind[ind >= start] += size

    def clear(self):
        """Remove all properties and their indices."""
        self._properties.clear()

    def size(self):
        """Return the total number of stored indices over all properties."""
        return sum(ind.size for ind in self.iterindices())

    def iterindices(self):
        """Iterate over the index arrays of all properties."""
        return iter(self._properties.values())

    def indices(self):
        """Return a list with the index arrays of all properties."""
        return list(self._properties.values())

    def properties_for(self, index):
        """
        Return, for each entry of `index`, the list of properties whose
        index array contains that entry (as an object array of lists).
        """
        return vectorize(lambda i: [prop for prop in self.iterproperties() if i in self[prop]], otypes=[list])(index)

    def add(self, prop, indices):
        """Merge `indices` into the indices already stored for `prop`."""
        try:
            self._properties[prop] = combine_indices(self._properties[prop], indices)
        except KeyError:
            # Only reached when the mapping has no default for unknown keys.
            self._properties[prop] = indices

    def remove(self, prop, indices):
        """
        Remove `indices` from `prop` and drop `prop` once it has none left.

        :returns: integer array of the indices that were actually removed
                  (empty if `prop` is unknown).
        """
        if prop not in self._properties:
            return numpy.array([]).astype(int)
        # Both results are computed from the stored array before it is
        # replaced or deleted below.
        diff = remove_indices(self[prop], indices)
        removed = numpy.intersect1d(self[prop], indices, assume_unique=True)
        if not index_empty(diff):
            self._properties[prop] = diff
        else:
            del self._properties[prop]
        return removed.astype(int)

    def __getitem__(self, prop):
        """Return the index array stored for `prop`."""
        return self._properties[prop]

    def __str__(self, *args, **kwargs):
        """Return a pretty-printed representation of the mapping."""
        import pprint
        return pprint.pformat(dict(self._properties))
2013-10-11 16:46:47 +01:00
def combine_indices(arr1, arr2):
    """Return the sorted union, without duplicates, of two index arrays."""
    merged = numpy.union1d(arr1, arr2)
    return merged
def remove_indices(arr, to_remove):
    """
    Return the entries of `arr` that do not occur in `to_remove`.

    Both inputs are passed to numpy as already duplicate-free
    (``assume_unique=True``).
    """
    remaining = numpy.setdiff1d(arr, to_remove, assume_unique=True)
    return remaining
def index_empty(index):
    """Return True when `index` holds no elements at all."""
    n_elements = numpy.size(index)
    return not n_elements
2014-02-12 17:11:55 +00:00
class ParameterIndexOperationsView(object):
    """
    Window onto the indices ``[offset, offset + size)`` of another index
    operations object.

    Indices handed out by the view are relative to `offset`; indices passed
    in are shifted by `offset` before they reach the wrapped object.
    Properties without any index inside the window are not visible.
    """

    def __init__(self, param_index_operations, offset, size):
        """
        :param param_index_operations: the wrapped index operations object
        :param offset: first index of the wrapped object covered by the view
        :param size: number of consecutive indices covered by the view
        """
        self._param_index_ops = param_index_operations
        self._offset = offset
        self._size = size

    def __getstate__(self):
        """Return the picklable state ``[wrapped object, offset, size]``."""
        return [self._param_index_ops, self._offset, self._size]

    def __setstate__(self, state):
        """Restore the state produced by `__getstate__`."""
        self._param_index_ops = state[0]
        self._offset = state[1]
        self._size = state[2]

    def _filter_index(self, ind):
        """Keep the entries of `ind` inside the window, relative to the offset."""
        inside = (ind >= self._offset) & (ind < (self._offset + self._size))
        return ind[inside] - self._offset

    def iteritems(self):
        """Iterate over ``(property, indices)`` for properties visible in the view."""
        for i, ind in self._param_index_ops.iteritems():
            ind2 = self._filter_index(ind)
            if ind2.size > 0:
                yield i, ind2

    def items(self):
        """Return a list of ``[property, indices]`` pairs."""
        return [[i, v] for i, v in self.iteritems()]

    def properties(self):
        """Return a list of the properties visible in the view."""
        return [i for i in self.iterproperties()]

    def iterproperties(self):
        """Iterate over the properties visible in the view."""
        for i, _ in self.iteritems():
            yield i

    def shift(self, start, size):
        """Not supported on views; always raises NotImplementedError."""
        raise NotImplementedError('Shifting only supported in original ParamIndexOperations')

    def clear(self):
        """Remove every index inside the window from the wrapped object."""
        # items() builds the full list first, because remove() changes the
        # wrapped object that iteritems() would otherwise still be reading.
        for i, ind in self.items():
            self._param_index_ops.remove(i, ind + self._offset)

    def size(self):
        """Return the number of indices inside the window, over all properties."""
        return sum(ind.size for ind in self.iterindices())

    def iterindices(self):
        """Iterate over the (relative) index arrays visible in the view."""
        for _, ind in self.iteritems():
            yield ind

    def indices(self):
        """Return a list of the (relative) index arrays visible in the view."""
        return [ind for ind in self.iterindices()]

    def properties_for(self, index):
        """
        Return, for each entry of `index`, the list of properties whose
        visible indices contain that entry (as an object array of lists).
        """
        return vectorize(lambda i: [prop for prop in self.iterproperties() if i in self[prop]], otypes=[list])(index)

    def add(self, prop, indices):
        """Add the view-relative `indices` to `prop` in the wrapped object."""
        self._param_index_ops.add(prop, numpy.asarray(indices) + self._offset)

    def remove(self, prop, indices):
        """
        Remove the view-relative `indices` from `prop` in the wrapped object.

        :returns: the removed indices, relative to the view (possibly empty).
        """
        removed = self._param_index_ops.remove(prop, numpy.asarray(indices) + self._offset)
        if removed.size > 0:
            # The wrapped object reports its own indices; translate them
            # back by the same offset that was added on the way in.
            return removed - self._offset
        return removed

    def __getitem__(self, prop):
        """
        Return the view-relative indices of `prop`.

        :raises KeyError: if `prop` has no index inside the window.
        """
        ind = self._filter_index(self._param_index_ops[prop])
        if ind.size > 0:
            return ind
        raise KeyError(prop)

    def __str__(self, *args, **kwargs):
        """Return a pretty-printed representation of the visible mapping."""
        import pprint
        return pprint.pformat(dict(self.iteritems()))

    def update(self, parameter_index_view):
        """Add every ``(property, indices)`` pair of `parameter_index_view`."""
        for i, v in parameter_index_view.iteritems():
            self.add(i, v)

    def copy(self):
        """Return a standalone ParameterIndexOperations holding the visible content."""
        return ParameterIndexOperations(dict(self.iteritems()))