[vardtc missing data] can handle non broadcastable selections

This commit is contained in:
Max Zwiessele 2014-05-20 14:45:34 +01:00
parent 8c8d06c8ae
commit 99699e9e02
2 changed files with 6 additions and 2 deletions

View file

@ -283,7 +283,11 @@ class VarDTCMissingData(LatentFunctionInference):
else: beta = beta_all
VVT_factor = (beta*y)
VVT_factor_all[v, ind].flat = VVT_factor.flat
try:
VVT_factor_all[v, ind].flat = VVT_factor.flat
except ValueError:
mult = np.ravel_multi_index((v.nonzero()[0][:,None],ind[None,:]), VVT_factor_all.shape)
VVT_factor_all.flat[mult] = VVT_factor
output_dim = y.shape[1]
psi0 = psi0_all[v]