[variational] posterior object copies adjusted

This commit is contained in:
mzwiessele 2014-05-13 08:35:25 +01:00
parent 4590a05d0f
commit 4f627c904f
2 changed files with 16 additions and 3 deletions

View file

@ -89,6 +89,13 @@ class Param(OptimizationHandlable, ObsAr):
def param_array(self): def param_array(self):
return self return self
@property
def values(self):
"""
Return self as numpy array view
"""
return self.view(np.ndarray)
@property @property
def gradient(self): def gradient(self):
""" """
@ -99,11 +106,11 @@ class Param(OptimizationHandlable, ObsAr):
""" """
if getattr(self, '_gradient_array_', None) is None: if getattr(self, '_gradient_array_', None) is None:
self._gradient_array_ = numpy.empty(self._realshape_, dtype=numpy.float64) self._gradient_array_ = numpy.empty(self._realshape_, dtype=numpy.float64)
return self._gradient_array_[self._current_slice_] return self._gradient_array_#[self._current_slice_]
@gradient.setter @gradient.setter
def gradient(self, val): def gradient(self, val):
self._gradient_array_[self._current_slice_] = val self._gradient_array_[:] = val
#=========================================================================== #===========================================================================
# Array operations -> done # Array operations -> done
@ -114,7 +121,10 @@ class Param(OptimizationHandlable, ObsAr):
#if not reduce(lambda a, b: a or numpy.any(b is Ellipsis), s, False) and len(s) <= self.ndim: #if not reduce(lambda a, b: a or numpy.any(b is Ellipsis), s, False) and len(s) <= self.ndim:
# s += (Ellipsis,) # s += (Ellipsis,)
new_arr = super(Param, self).__getitem__(s, *args, **kwargs) new_arr = super(Param, self).__getitem__(s, *args, **kwargs)
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base try:
new_arr._current_slice_ = s
new_arr._gradient_array_ = self.gradient[s]
new_arr._original_ = self.base is new_arr.base
except AttributeError: pass # returning 0d array or float, double etc except AttributeError: pass # returning 0d array or float, double etc
return new_arr return new_arr

View file

@ -100,6 +100,9 @@ class VariationalPosterior(Parameterized):
n.__dict__.update(dc) n.__dict__.update(dc)
n._parameters_[dc['mean']._parent_index_] = dc['mean'] n._parameters_[dc['mean']._parent_index_] = dc['mean']
n._parameters_[dc['variance']._parent_index_] = dc['variance'] n._parameters_[dc['variance']._parent_index_] = dc['variance']
n._gradient_array_ = None
oversize = self.size - self.mean.size - self.variance.size
n.size = n.mean.size + n.variance.size + oversize
n.ndim = n.mean.ndim n.ndim = n.mean.ndim
n.shape = n.mean.shape n.shape = n.mean.shape
n.num_data = n.mean.shape[0] n.num_data = n.mean.shape[0]