[logging] more on logging

This commit is contained in:
mzwiessele 2014-07-02 10:13:30 -07:00
parent eb9fb180fb
commit a9443417d7
3 changed files with 19 additions and 8 deletions

View file

@ -9,7 +9,7 @@ import numpy as np
from ...util.misc import param_to_array
from . import LatentFunctionInference
log_2_pi = np.log(2*np.pi)
import logging
import logging, itertools
logger = logging.getLogger('vardtc')
class VarDTC(LatentFunctionInference):
@ -228,18 +228,28 @@ class VarDTCMissingData(LatentFunctionInference):
self._subarray_indices = []
csa = common_subarrays(inan, 1)
size = len(csa)
next_ten = 0
for i, (v,ind) in enumerate(csa.iteritems()):
if not np.all(v):
logger.info('preparing subarrays {:3.3%}'.format((i+1.)/size))
if ((i+1.)/size) >= next_ten:
logger.info('preparing subarrays {:3%}'.format((i+1.)/size))
next_ten += max(.1, 1./size)
v = ~np.array(v, dtype=bool)
ind = np.array(ind, dtype=int)
if ind.size == Y.shape[1]:
ind = slice(None)
self._subarray_indices.append([v,ind])
logger.info('preparing subarrays Y')
#logger.info('preparing subarrays {:3.3%}'.format((i+1.)/size))
#Ys = [Y[v, :][:, ind] for v, ind in self._subarray_indices]
logger.info('preparing traces Y')
next_ten = [0.]
count = itertools.count()
def trace(y, v, ind):
i = count.next()
if ((i+1.)/size) >= next_ten[0]:
logger.info('preparing traces {:3%}'.format((i+1.)/size))
next_ten[0] += .1
y = y[v,:][:,ind]
return np.einsum('ij,ij->', y,y)
traces = [trace(Y, v, ind) for v, ind in self._subarray_indices]