[var dtc missing]

This commit is contained in:
mzwiessele 2014-07-02 12:30:18 -07:00
parent c8da45aa8f
commit 8dacea2c13

View file

@ -229,18 +229,27 @@ class VarDTCMissingData(LatentFunctionInference):
#csa = common_subarrays(inan, 1)
size = Y.shape[1]
#logger.info('preparing subarrays {:3.3%}'.format((i+1.)/size))
#Ys = [Y[v, :][:, ind] for v, ind in self._subarray_indices]
Ys = []
next_ten = [0.]
count = itertools.count()
for v, y in itertools.izip(inan.T, Y.T[:,:,None]):
i = count.next()
if ((i+1.)/size) >= next_ten[0]:
logger.info('preparing subarrays {:>6.1%}'.format((i+1.)/size))
next_ten[0] += .1
Ys.append(y[v,:])
next_ten = [0.]
count = itertools.count()
def trace(y):
i = count.next()
if ((i+1.)/size) >= next_ten[0]:
logger.info('preparing traces {:> 3.1%}'.format((i+1.)/size))
logger.info('preparing traces {:>6.1%}'.format((i+1.)/size))
next_ten[0] += .1
y = y[inan[:,i],i:i+1]
return np.einsum('ij,ij->', y,y)
traces = [trace(Y) for _ in xrange(size)]
return traces
return Ys, traces
else:
self._subarray_indices = [[slice(None),slice(None)]]
return [Y], [(Y**2).sum()]
@ -257,7 +266,7 @@ class VarDTCMissingData(LatentFunctionInference):
psi1_all = kern.K(X, Z)
psi2_all = None
traces = self._Y(Y)
Ys, traces = self._Y(Y)
beta_all = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6)
het_noise = beta_all.size != 1
@ -287,9 +296,14 @@ class VarDTCMissingData(LatentFunctionInference):
#logger.info('computing dimension-wise likelihood and derivatives')
#size = len(Ys)
for i, [y, v, trYYT] in enumerate(itertools.izip(Y.T[:,:,None], self._inan.T, traces)):
y = y[v]
#logger.info('{:.3%} dimensions:{}'.format((i+1.)/size, ind))
#size = Y.shape[1]
#next_ten = 0
for i, [y, v, trYYT] in enumerate(itertools.izip(Ys, self._inan.T, traces)):
#if ((i+1.)/size) >= next_ten:
# logger.info('preparing traces {:> 6.1%}'.format((i+1.)/size))
# next_ten += .1
#y = y[v]
if het_noise: beta = beta_all[i]
else: beta = beta_all