mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 13:32:39 +02:00
added index_operations and deleted them from paramter
This commit is contained in:
parent
9b0b63dd4d
commit
e8437f7ec3
2 changed files with 52 additions and 15 deletions
50
GPy/core/index_operations.py
Normal file
50
GPy/core/index_operations.py
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
'''
|
||||||
|
Created on Oct 2, 2013
|
||||||
|
|
||||||
|
@author: maxzwiessele
|
||||||
|
'''
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
class ParameterIndexOperations(object):
|
||||||
|
'''
|
||||||
|
Index operations for storing parameter index restrictions
|
||||||
|
This class enables indexing with slices retrieved from object.__getitem__ calls.
|
||||||
|
|
||||||
|
:param shape: shape of parameter, handled by this index restriction class
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, shape):
|
||||||
|
self.restrictions = {}
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def get_restriction_indices(self, restriction):
|
||||||
|
"""
|
||||||
|
get indices for restriction restriction.
|
||||||
|
these indices can be used as X[indices], which will be a flattened array of
|
||||||
|
all restricted elements
|
||||||
|
"""
|
||||||
|
return numpy.unravel_index(self.restrictions[restriction], self.shape)
|
||||||
|
|
||||||
|
def add_restriction(self, restriction, indices):
|
||||||
|
ind = self._create_raveled_indices(indices)
|
||||||
|
if restriction in self.restrictions:
|
||||||
|
self.restrictions[restriction] = numpy.union1d(self.restrictions[restriction], ind)
|
||||||
|
else:
|
||||||
|
self.restrictions[restriction] = ind
|
||||||
|
|
||||||
|
def remove_restriction(self, restriction, indices):
|
||||||
|
if restriction in self.restrictions:
|
||||||
|
ind = self._create_raveled_indices(indices)
|
||||||
|
diff = numpy.setdiff1d(self.restrictions[restriction], ind, True)
|
||||||
|
if numpy.size(diff):
|
||||||
|
self.restrictions[restriction] = diff
|
||||||
|
else:
|
||||||
|
del self.restrictions[restriction]
|
||||||
|
|
||||||
|
def _create_raveled_indices(self, indices):
|
||||||
|
if isinstance(indices, (tuple, list)):
|
||||||
|
i = [slice(None)] + list(indices)
|
||||||
|
else:
|
||||||
|
i = [slice(None), indices]
|
||||||
|
return numpy.array(numpy.ravel_multi_index(numpy.indices(self.shape)[i], self.shape)).flatten()
|
||||||
|
|
@ -7,6 +7,7 @@ import re
|
||||||
import itertools
|
import itertools
|
||||||
import numpy
|
import numpy
|
||||||
from GPy.core.transformations import Logexp
|
from GPy.core.transformations import Logexp
|
||||||
|
from GPy.core.index_operations import ParameterIndexOperations
|
||||||
_index_re = re.compile('(?:_(\d+))+') # pattern match for indices
|
_index_re = re.compile('(?:_(\d+))+') # pattern match for indices
|
||||||
|
|
||||||
def translate_param_names_to_parameters(param_names):
|
def translate_param_names_to_parameters(param_names):
|
||||||
|
|
@ -73,17 +74,6 @@ class Parameters(object):
|
||||||
return '\n'.join([x.__str__(format_spec=format_spec) for x in self._params])
|
return '\n'.join([x.__str__(format_spec=format_spec) for x in self._params])
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class ParameterIndexing(object):
|
|
||||||
def __init__(self, corresponding_param):
|
|
||||||
self.properties = {}
|
|
||||||
self.param = corresponding_param
|
|
||||||
def add(self, prop, s):
|
|
||||||
if prop in self.properties.keys():
|
|
||||||
self.properties[prop] = self.combine_indices(self.properties[prop], s)
|
|
||||||
else:
|
|
||||||
self.properties[prop] = [numpy.r_[st] for st in s]
|
|
||||||
def combine_indices(self, s1, s2):
|
|
||||||
return [numpy.union1d(numpy.r_[ar1], numpy.r_[ar2]) for ar1, ar2 in itertools.izip_longest(s1, s2)]
|
|
||||||
|
|
||||||
class Parameter(object):
|
class Parameter(object):
|
||||||
tied_to = [] # list of parameters this parameter is tied to
|
tied_to = [] # list of parameters this parameter is tied to
|
||||||
|
|
@ -91,7 +81,7 @@ class Parameter(object):
|
||||||
|
|
||||||
def __init__(self, name, value, constraint=None, *args, **kwargs):
|
def __init__(self, name, value, constraint=None, *args, **kwargs):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.constraints = ParameterIndexing(self)
|
self.constraints = ParameterIndexOperations(self)
|
||||||
self._value = value
|
self._value = value
|
||||||
self._current_slice = slice(None)
|
self._current_slice = slice(None)
|
||||||
|
|
||||||
|
|
@ -129,11 +119,8 @@ class Parameter(object):
|
||||||
|
|
||||||
def _get_params_transformed(self):
|
def _get_params_transformed(self):
|
||||||
params = self.value.copy()
|
params = self.value.copy()
|
||||||
import ipdb;ipdb.set_trace()
|
|
||||||
return
|
|
||||||
|
|
||||||
def constrain_positive(self):
|
def constrain_positive(self):
|
||||||
import ipdb;ipdb.set_trace()
|
|
||||||
self.constraints.add(Logexp(), self._current_slice)
|
self.constraints.add(Logexp(), self._current_slice)
|
||||||
self._current_slice = slice(None)
|
self._current_slice = slice(None)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue