redesign the tie framework

This commit is contained in:
Zhenwen Dai 2014-09-03 15:56:33 +01:00
parent 1abb3087ae
commit 2463a954f7
3 changed files with 214 additions and 109 deletions

View file

@ -84,6 +84,7 @@ class Param(Parameterizable, ObsAr):
self._gradient_array_ = getattr(obj, '_gradient_array_', None)
self.constraints = getattr(obj, 'constraints', None)
self.priors = getattr(obj, 'priors', None)
self._tie_ = getattr(obj, '_tie_', None)
@property
def param_array(self):
@ -114,6 +115,16 @@ class Param(Parameterizable, ObsAr):
@gradient.setter
def gradient(self, val):
self._gradient_array_[:] = val
@property
def tie(self):
if getattr(self, '_tie_', None) is None:
self._tie_ = numpy.zeros(self._realshape_, dtype=numpy.uint32)
return self._tie_
@tie.setter
def tie(self, val):
self._tie_[:] = val
#===========================================================================
# Array operations -> done
@ -127,6 +138,7 @@ class Param(Parameterizable, ObsAr):
try:
new_arr._current_slice_ = s
new_arr._gradient_array_ = self.gradient[s]
new_arr._tie_ = self.tie[s]
new_arr._original_ = self.base is new_arr.base
except AttributeError: pass # returning 0d array or float, double etc
return new_arr
@ -237,7 +249,7 @@ class Param(Parameterizable, ObsAr):
def _ties_str(self):
return ['']
def _ties_for(self, ravi):
return [['N/A']]*ravi.size
return [['N/A' if self.tie[i]==0 else str(self.tie[i])] for i in xrange(ravi.size)]
def __repr__(self, *args, **kwargs):
name = "\033[1m{x:s}\033[0;0m:\n".format(
x=self.hierarchy_name())