vardtc updates

This commit is contained in:
Max Zwiessele 2014-03-24 13:32:28 +00:00
parent 8d1cae6459
commit 6b8e418597

View file

@ -179,6 +179,7 @@ class VarDTC(object):
return post, log_marginal, grad_dict
class VarDTCMissingData(object):
const_jitter = 1e-6
def __init__(self, limit=1):
from ...util.caching import Cacher
self._Y = Cacher(self._subarray_computations, limit)
@ -250,7 +251,7 @@ class VarDTCMissingData(object):
for y, trYYT, [v, ind] in itertools.izip(Ys, traces, self._subarray_indices):
if het_noise: beta = beta_all[ind]
else: beta = beta_all[0]
else: beta = beta_all
VVT_factor = (beta*y)
VVT_factor_all[v, ind].flat = VVT_factor.flat
@ -311,7 +312,7 @@ class VarDTCMissingData(object):
het_noise, uncertain_inputs, LB,
_LBi_Lmi_psi1Vf, DBi_plus_BiPBi, Lm, A,
psi0, psi1, beta,
data_fit, num_data, output_dim, trYYT)
data_fit, num_data, output_dim, trYYT, Y)
if full_VVT_factor: woodbury_vector[:, ind] = Cpsi1Vf
else: