subclassing ndarray almost functional

This commit is contained in:
Max Zwiessele 2013-10-04 10:59:11 +01:00
parent 2ced667193
commit 738f6a5928

View file

@ -48,6 +48,9 @@ class Parameters(object):
return [x.name for x in self._params] return [x.name for x in self._params]
@property @property
def size(self):
return sum(self.sizes)
@property
def sizes(self): def sizes(self):
return [x.size for x in self._params] return [x.size for x in self._params]
@property @property
@ -76,32 +79,20 @@ class Parameters(object):
pass pass
class Parameter(numpy.ndarray): class Param(numpy.ndarray):
tied_to = [] # list of parameters this parameter is tied to tied_to = [] # list of parameters this parameter is tied to
fixed = False # if this parameter is fixed fixed = False # if this parameter is fixed
__array_priority__ = 3.0 __array_priority__ = -1
def __new__(cls, name, input_array, info=None): def __new__(cls, name, input_array, constraints=None):
obj = numpy.array(input_array).view(cls) obj = numpy.array(input_array).view(cls)
obj.name = name obj.name = name
obj._current_slice = slice(None) obj._current_slice = slice(None)
obj._realshape = input_array.shape obj._realshape = input_array.shape
# def attribute_func(value, name): if constraints is None:
# value_func = self.value.__getattribute__(name)
# def f(*args, **kwargs):
# with self.slicing():
# raise AttributeError("This is a parameter view, use self.value for array view")
# try:
# f.__doc__ = value_func.__doc__
# except AttributeError:
# # no docstring present
# pass
# return f
#
# for name in dir(value):
# if not hasattr(self, name):
# self.__setattr__(name, attribute_func(value, name))#value.__getattribute__(name))
obj.constraints = ParameterIndexOperations(obj) obj.constraints = ParameterIndexOperations(obj)
else:
obj.constraints = constraints
return obj return obj
def __array_finalize__(self, obj): def __array_finalize__(self, obj):
@ -115,51 +106,10 @@ class Parameter(numpy.ndarray):
def __array_wrap__(self, out_arr, context=None): def __array_wrap__(self, out_arr, context=None):
return numpy.ndarray.__array_wrap__(self, out_arr, context) return numpy.ndarray.__array_wrap__(self, out_arr, context)
# def __init__(self, name, value, constraint=None, *args, **kwargs):
# self.constraints = ParameterIndexOperations(self)
#
# self._value = value
# self._current_slice = slice(None)
# def attribute_func(value, name):
# value_func = self.value.__getattribute__(name)
# def f(*args, **kwargs):
# with self.slicing():
# raise AttributeError("This is a parameter view, use self.value for array view")
# try:
# f.__doc__ = value_func.__doc__
# except AttributeError:
# # no docstring present
# pass
# return f
#
# for name in dir(value):
# if not hasattr(self, name):
# self.__setattr__(name, attribute_func(value, name))#value.__getattribute__(name))
# super(Parameter, self).__init__(value, *args, **kwargs)
@property @property
def value(self): def value(self):
return self#self.base[self._current_slice] return self#self.base[self._current_slice]
# @value.setter
# def value(self, value):
# self.base[self._current_slice] = value
# @property
# def value(self):
# return self._value[self._current_slice]
# @value.setter
# def value(self, value):
# self._value[self._current_slice] = value
# @property
# def size(self):
# return self.value.size
# @property
# def shape(self):
# return self.value.shape
# @property
# def realshape(self):
# return self._value.shape
@property @property
def _desc(self): def _desc(self):
if self.size <= 1: if self.size <= 1:
@ -171,19 +121,19 @@ class Parameter(numpy.ndarray):
return ' '.join([str(c) if c else '' for c in self.constraints.keys()]) return ' '.join([str(c) if c else '' for c in self.constraints.keys()])
def _set_params(self, param): def _set_params(self, param):
with self.slicing(): #with self.slicing():
self.value.flat = param self.value.flat = param
def _get_params(self): def _get_params(self):
with self.slicing(): #with self.slicing():
return self.value.flat return self.value.flat
def _get_params_transformed(self): def _get_params_transformed(self):
with self.slicing(): #with self.slicing():
params = self.value.copy() params = self.value.copy()
def constrain(self, constraint): def constrain(self, constraint):
with self.slicing(): #with self.slicing():
self.constraints.add(constraint, self._current_slice) self.constraints.add(constraint, self._current_slice)
def constrain_positive(self): def constrain_positive(self):
@ -193,7 +143,7 @@ class Parameter(numpy.ndarray):
self.constrain(NegativeLogexp()) self.constrain(NegativeLogexp())
def unconstrain(self, constraints=None): def unconstrain(self, constraints=None):
with self.slicing(): #with self.slicing():
if constraints is None: if constraints is None:
constraints = self.constraints.keys() constraints = self.constraints.keys()
elif not isinstance(constraints, (tuple, list, numpy.ndarray)): elif not isinstance(constraints, (tuple, list, numpy.ndarray)):
@ -204,37 +154,18 @@ class Parameter(numpy.ndarray):
def unconstrain_positive(self): def unconstrain_positive(self):
self.unconstrain(Logexp()) self.unconstrain(Logexp())
def __getitem__(self, s, *args, **kwargs): def __getitem__(self, s, *args, **kwargs):
#self._current_slice = s
# import ipdb;ipdb.set_trace()
new_arr = numpy.ndarray.__getitem__(self, s, *args, **kwargs) new_arr = numpy.ndarray.__getitem__(self, s, *args, **kwargs)
new_arr = new_arr.view(self.__class__)
try: try:
new_arr._current_slice = s new_arr._current_slice = s
except AttributeError: except AttributeError:
# returning 0d array or float, double etc:
pass pass
return new_arr return new_arr
# def __getitem__(self, s):
# try:
# self.value[s]
# self._current_slice = s#[s if s else slice(s2) for s,s2 in itertools.izip_longest([s], self.shape, fillvalue=None)]
# return self
# except IndexError as i:
# self._current_slice = slice(None)
# raise i
#
# def __setitem__(self, s, value):
# try:
# self.value[s] = value
# self._current_slice = slice(None)
# return self
# except IndexError as i:
# raise i
#
#
def __repr__(self, *args, **kwargs): def __repr__(self, *args, **kwargs):
view = str(self.value) return super(Param, self).__repr__(*args, **kwargs)
view = repr(self.value)
return view return view
def __str__(self, format_spec=None): def __str__(self, format_spec=None):
@ -256,21 +187,12 @@ class Parameter(numpy.ndarray):
return "\n".join([header]+[" {i:^{3}s} | {x: >{1}.{2}G} | {c:^{0}s}".format(lc,lx,p,li, x=x, c=constr.next(), i=i) for i,x in itertools.izip(ind,self.value.flat)]) return "\n".join([header]+[" {i:^{3}s} | {x: >{1}.{2}G} | {c:^{0}s}".format(lc,lx,p,li, x=x, c=constr.next(), i=i) for i,x in itertools.izip(ind,self.value.flat)])
return format_spec.format(self=self) return format_spec.format(self=self)
import contextlib
@contextlib.contextmanager
def slicing(self, *args, **kwargs):
try:
yield
finally:
self._current_slice = slice(None)
del contextlib
if __name__ == '__main__': if __name__ == '__main__':
X = numpy.random.randn(2,4) X = numpy.random.randn(2,4)
p = Parameter("X", X) p = Param("X", X)
p2 = Parameter("Y", numpy.random.randn(3,1)) p2 = Param("Y", numpy.random.randn(3,1))
p3 = Parameter("rbf_variance", numpy.random.rand(1)) p3 = Param("rbf_variance", numpy.random.rand(1))
p4 = Parameter("rbf_lengthscale", numpy.random.rand(2)) p4 = Param("rbf_lengthscale", numpy.random.rand(2))
params = Parameters([p,p2,p3,p4]) params = Parameters([p,p2,p3,p4])
params.X[1].constrain_positive() params.X[1].constrain_positive()
print params.X print params.X