ndarray subclass cleaned up

This commit is contained in:
Max Zwiessele 2013-10-04 11:01:15 +01:00
parent 738f6a5928
commit 89c0d7e8eb

View file

@ -102,9 +102,6 @@ class Param(numpy.ndarray):
self._realshape = getattr(obj, '_realshape', None)
self.constraints = getattr(obj, 'constraints', None)
self._current_slice = getattr(obj, '_current_slice', None)
def __array_wrap__(self, out_arr, context=None):
return numpy.ndarray.__array_wrap__(self, out_arr, context)
@property
def value(self):
@ -121,20 +118,16 @@ class Param(numpy.ndarray):
return ' '.join([str(c) if c else '' for c in self.constraints.keys()])
def _set_params(self, param):
#with self.slicing():
self.value.flat = param
self.value.flat = param
def _get_params(self):
#with self.slicing():
return self.value.flat
return self.value.flat
def _get_params_transformed(self):
#with self.slicing():
params = self.value.copy()
params = self.value.copy()
def constrain(self, constraint):
#with self.slicing():
self.constraints.add(constraint, self._current_slice)
self.constraints.add(constraint, self._current_slice)
def constrain_positive(self):
self.constrain(Logexp())
@ -143,13 +136,12 @@ class Param(numpy.ndarray):
self.constrain(NegativeLogexp())
def unconstrain(self, constraints=None):
#with self.slicing():
if constraints is None:
constraints = self.constraints.keys()
elif not isinstance(constraints, (tuple, list, numpy.ndarray)):
constraints = [constraints]
for constr in constraints:
self.constraints.remove(constr, self._current_slice)
if constraints is None:
constraints = self.constraints.keys()
elif not isinstance(constraints, (tuple, list, numpy.ndarray)):
constraints = [constraints]
for constr in constraints:
self.constraints.remove(constr, self._current_slice)
def unconstrain_positive(self):
self.unconstrain(Logexp())
@ -169,23 +161,22 @@ class Param(numpy.ndarray):
return view
def __str__(self, format_spec=None):
#with self.slicing():
if format_spec is None:
constr_matrix = numpy.empty(self._realshape, dtype=object)
constr_matrix[:] = ''
for constr, indices in self.constraints.iteritems():
constr_matrix[indices] = numpy.vectorize(lambda x: " ".join([x,str(constr)]) if x else str(constr))(constr_matrix[indices])
constr_matrix = constr_matrix.astype(numpy.string_)[self._current_slice]
p = numpy.get_printoptions()['precision']
constr = constr_matrix.flat
ind = numpy.array(list(itertools.product(*itertools.imap(range, self._realshape))))[self.constraints.create_raveled_indices(self._current_slice),...]
c_name, x_name, i_name = "Constraint", "Value", "Index"
lc = max(reduce(lambda a,b: max(a, len(b)), constr_matrix.flat, 0), len(c_name))
lx = max(reduce(lambda a,b: max(a, len("{x:=.{0}G}".format(p,x=b))), self.value.flat, 0), len(x_name))
li = max(reduce(lambda a,b: max(a, len(str(b))), ind, 0), len(i_name))
header = " {i:^{3}s} | {x:^{1}s} | {c:^{0}s}".format(lc,lx,p,li, x=x_name, c=c_name, i=i_name)
return "\n".join([header]+[" {i:^{3}s} | {x: >{1}.{2}G} | {c:^{0}s}".format(lc,lx,p,li, x=x, c=constr.next(), i=i) for i,x in itertools.izip(ind,self.value.flat)])
return format_spec.format(self=self)
if format_spec is None:
constr_matrix = numpy.empty(self._realshape, dtype=object)
constr_matrix[:] = ''
for constr, indices in self.constraints.iteritems():
constr_matrix[indices] = numpy.vectorize(lambda x: " ".join([x,str(constr)]) if x else str(constr))(constr_matrix[indices])
constr_matrix = constr_matrix.astype(numpy.string_)[self._current_slice]
p = numpy.get_printoptions()['precision']
constr = constr_matrix.flat
ind = numpy.array(list(itertools.product(*itertools.imap(range, self._realshape))))[self.constraints.create_raveled_indices(self._current_slice),...]
c_name, x_name, i_name = "Constraint", "Value", "Index"
lc = max(reduce(lambda a,b: max(a, len(b)), constr_matrix.flat, 0), len(c_name))
lx = max(reduce(lambda a,b: max(a, len("{x:=.{0}G}".format(p,x=b))), self.value.flat, 0), len(x_name))
li = max(reduce(lambda a,b: max(a, len(str(b))), ind, 0), len(i_name))
header = " {i:^{3}s} | {x:^{1}s} | {c:^{0}s}".format(lc,lx,p,li, x=x_name, c=c_name, i=i_name)
return "\n".join([header]+[" {i:^{3}s} | {x: >{1}.{2}G} | {c:^{0}s}".format(lc,lx,p,li, x=x, c=constr.next(), i=i) for i,x in itertools.izip(ind,self.value.flat)])
return format_spec.format(self=self)
if __name__ == '__main__':
X = numpy.random.randn(2,4)