merged and updated slicing operations

This commit is contained in:
Max Zwiessele 2014-03-27 09:08:28 +00:00
commit c65a1e3544
11 changed files with 98 additions and 53 deletions

View file

@ -5,30 +5,27 @@ Created on 11 Mar 2014
'''
from ...core.parameterization.parameterized import ParametersChangedMeta
import numpy as np
import functools
def put_clean(dct, name, *args, **kw):
if name in dct:
dct['_clean_{}'.format(name)] = dct[name]
dct[name] = _slice_wrapper(None, dct[name], *args, **kw)
class KernCallsViaSlicerMeta(ParametersChangedMeta):
def __call__(self, *args, **kw):
instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
instance.K = _Slice_wrapper(instance, instance.K)
instance.Kdiag = _Slice_wrapper_diag(instance, instance.Kdiag)
instance.update_gradients_full = _Slice_wrapper_derivative(instance, instance.update_gradients_full)
instance.update_gradients_diag = _Slice_wrapper_diag_derivative(instance, instance.update_gradients_diag)
instance.gradients_X = _Slice_wrapper_grad_X(instance, instance.gradients_X)
instance.gradients_X_diag = _Slice_wrapper_grad_X_diag(instance, instance.gradients_X_diag)
instance.psi0 = _Slice_wrapper(instance, instance.psi0)
instance.psi1 = _Slice_wrapper(instance, instance.psi1)
instance.psi2 = _Slice_wrapper(instance, instance.psi2)
instance.update_gradients_expectations = _Slice_wrapper_psi_stat_derivative_no_ret(instance, instance.update_gradients_expectations)
instance.gradients_Z_expectations = _Slice_wrapper_psi_stat_derivative_Z(instance, instance.gradients_Z_expectations)
instance.gradients_qX_expectations = _Slice_wrapper_psi_stat_derivative(instance, instance.gradients_qX_expectations)
instance.parameters_changed()
return instance
def __new__(cls, name, bases, dct):
put_clean(dct, 'K')
put_clean(dct, 'Kdiag', diag=True)
put_clean(dct, 'update_gradients_full', diag=False, derivative=True)
put_clean(dct, 'gradients_X', diag=False, derivative=True, ret_X=True)
put_clean(dct, 'gradients_X_diag', diag=True, derivative=True, ret_X=True)
put_clean(dct, 'psi0', diag=False, derivative=False)
put_clean(dct, 'psi1', diag=False, derivative=False)
put_clean(dct, 'psi2', diag=False, derivative=False)
put_clean(dct, 'update_gradients_expectations', derivative=True, psi_stat=True)
put_clean(dct, 'gradients_Z_expectations', derivative=True, psi_stat_Z=True, ret_X=True)
put_clean(dct, 'gradients_qX_expectations', derivative=True, psi_stat=True, ret_X=True)
return super(KernCallsViaSlicerMeta, cls).__new__(cls, name, bases, dct)
class _Slice_wrap(object):
def __init__(self, instance, f):
self.k = instance