added index_operations and deleted them from paramter

This commit is contained in:
Max Zwießele 2013-10-02 12:38:07 +01:00
parent 9b0b63dd4d
commit e8437f7ec3
2 changed files with 52 additions and 15 deletions

View file

@ -0,0 +1,50 @@
'''
Created on Oct 2, 2013
@author: maxzwiessele
'''
import numpy
class ParameterIndexOperations(object):
'''
Index operations for storing parameter index restrictions
This class enables indexing with slices retrieved from object.__getitem__ calls.
:param shape: shape of parameter, handled by this index restriction class
'''
def __init__(self, shape):
self.restrictions = {}
self.shape = shape
def get_restriction_indices(self, restriction):
"""
get indices for restriction restriction.
these indices can be used as X[indices], which will be a flattened array of
all restricted elements
"""
return numpy.unravel_index(self.restrictions[restriction], self.shape)
def add_restriction(self, restriction, indices):
ind = self._create_raveled_indices(indices)
if restriction in self.restrictions:
self.restrictions[restriction] = numpy.union1d(self.restrictions[restriction], ind)
else:
self.restrictions[restriction] = ind
def remove_restriction(self, restriction, indices):
if restriction in self.restrictions:
ind = self._create_raveled_indices(indices)
diff = numpy.setdiff1d(self.restrictions[restriction], ind, True)
if numpy.size(diff):
self.restrictions[restriction] = diff
else:
del self.restrictions[restriction]
def _create_raveled_indices(self, indices):
if isinstance(indices, (tuple, list)):
i = [slice(None)] + list(indices)
else:
i = [slice(None), indices]
return numpy.array(numpy.ravel_multi_index(numpy.indices(self.shape)[i], self.shape)).flatten()

View file

@ -7,6 +7,7 @@ import re
import itertools import itertools
import numpy import numpy
from GPy.core.transformations import Logexp from GPy.core.transformations import Logexp
from GPy.core.index_operations import ParameterIndexOperations
_index_re = re.compile('(?:_(\d+))+') # pattern match for indices _index_re = re.compile('(?:_(\d+))+') # pattern match for indices
def translate_param_names_to_parameters(param_names): def translate_param_names_to_parameters(param_names):
@ -73,17 +74,6 @@ class Parameters(object):
return '\n'.join([x.__str__(format_spec=format_spec) for x in self._params]) return '\n'.join([x.__str__(format_spec=format_spec) for x in self._params])
pass pass
class ParameterIndexing(object):
def __init__(self, corresponding_param):
self.properties = {}
self.param = corresponding_param
def add(self, prop, s):
if prop in self.properties.keys():
self.properties[prop] = self.combine_indices(self.properties[prop], s)
else:
self.properties[prop] = [numpy.r_[st] for st in s]
def combine_indices(self, s1, s2):
return [numpy.union1d(numpy.r_[ar1], numpy.r_[ar2]) for ar1, ar2 in itertools.izip_longest(s1, s2)]
class Parameter(object): class Parameter(object):
tied_to = [] # list of parameters this parameter is tied to tied_to = [] # list of parameters this parameter is tied to
@ -91,7 +81,7 @@ class Parameter(object):
def __init__(self, name, value, constraint=None, *args, **kwargs): def __init__(self, name, value, constraint=None, *args, **kwargs):
self.name = name self.name = name
self.constraints = ParameterIndexing(self) self.constraints = ParameterIndexOperations(self)
self._value = value self._value = value
self._current_slice = slice(None) self._current_slice = slice(None)
@ -129,11 +119,8 @@ class Parameter(object):
def _get_params_transformed(self): def _get_params_transformed(self):
params = self.value.copy() params = self.value.copy()
import ipdb;ipdb.set_trace()
return
def constrain_positive(self): def constrain_positive(self):
import ipdb;ipdb.set_trace()
self.constraints.add(Logexp(), self._current_slice) self.constraints.add(Logexp(), self._current_slice)
self._current_slice = slice(None) self._current_slice = slice(None)